Commit f409fbf0 authored by Yuxin Wu

Move gradient average utilities

parent d53b5c4c
@@ -20,7 +20,7 @@ def global_import(name):

 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utils']
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
@@ -100,11 +100,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         new_tower_grads = []
         with tf.name_scope('AvgGrad'):
             for i, grad_and_vars in enumerate(zip(*tower_grads)):
-                # Ngpu * 2
-                all_grads = [g for (g, _) in grad_and_vars]
+                v = grad_and_vars[0][1]  # Ngpu * 2
                 with tf.device(devices[i % nr_device]):
-                    v = grad_and_vars[0][1]
-                    # average gradient
+                    all_grads = [g for (g, _) in grad_and_vars]
                     grad = tf.multiply(
                         tf.add_n(all_grads), 1.0 / nr_device)
                     new_tower_grads.append((grad, v))
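A minimal sketch of the round-robin placement kept by the hunk above (device and variable names here are assumed for illustration, not taken from the commit): the averaged gradient of the i-th variable is built on devices[i % nr_device], which spreads the add_n/multiply work across the available devices.

    # illustrative only: hypothetical device and variable names
    devices = ['/gpu:0', '/gpu:1']
    nr_device = len(devices)
    for i, var_name in enumerate(['conv1/W', 'conv1/b', 'fc/W']):
        print(var_name, '->', devices[i % nr_device])
    # conv1/W -> /gpu:0
    # conv1/b -> /gpu:1
    # fc/W -> /gpu:0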
@@ -14,7 +14,9 @@ from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..utils.naming import TOWER_FREEZE_KEYS
-from .utils import LeastLoadedDeviceSetter, override_to_local_variable
+from .utils import (
+    LeastLoadedDeviceSetter, override_to_local_variable,
+    allreduce_grads, average_grads)

 __all__ = ['GraphBuilder',
@@ -123,25 +125,6 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         assert ps_device in ['cpu', 'gpu']
         self.ps_device = ps_device

-    @staticmethod
-    def _average_grads(tower_grads):
-        # tower_grads: Ngpu x Nvar x 2
-        nr_tower = len(tower_grads)
-        if nr_tower == 1:
-            return tower_grads[0]
-        new_tower_grads = []
-        with tf.name_scope('AvgGrad'):
-            for grad_and_vars in zip(*tower_grads):
-                # Ngpu * 2
-                v = grad_and_vars[0][1]
-                all_grads = [g for (g, _) in grad_and_vars]
-                with tf.device(v.device):       # colocate summed grad with var
-                    grad = tf.multiply(
-                        tf.add_n(all_grads), 1.0 / nr_tower)
-                    new_tower_grads.append((grad, v))
-        return new_tower_grads
-
     def build(self, get_grad_fn, get_opt_fn):
         """
         Args:
@@ -166,7 +149,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         # self.train_op = tf.group(*ops)
         # return

-        grads = SyncMultiGPUParameterServerBuilder._average_grads(grad_list)
+        grads = average_grads(grad_list)
         # grads = grad_list[0]

         opt = get_opt_fn()
@@ -187,27 +170,6 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):

     See https://www.tensorflow.org/performance/benchmarks for details.
     """

-    @staticmethod
-    def _allreduce_grads(tower_grads):
-        from tensorflow.contrib import nccl
-        nr_tower = len(tower_grads)
-        if nr_tower == 1:
-            return [[x] for x in tower_grads[0]]
-        new_tower_grads = []
-        with tf.name_scope('AvgGrad'):
-            for grad_and_vars in zip(*tower_grads):
-                v = grad_and_vars[0][1]
-                grads = [g for g, _ in grad_and_vars]
-                summed = nccl.all_sum(grads)
-                grads_for_a_var = []
-                for (_, v), g in zip(grad_and_vars, summed):
-                    with tf.device(g.device):
-                        g = tf.multiply(g, 1.0 / nr_tower)
-                        grads_for_a_var.append((g, v))
-                new_tower_grads.append(grads_for_a_var)
-        # NVar * NGPU * 2
-        return new_tower_grads
-
     def build(self, get_grad_fn, get_opt_fn):
         """
@@ -231,13 +193,14 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             get_grad_fn,
             # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(self.towers) - 1))
-        grads = SyncMultiGPUReplicatedBuilder._allreduce_grads(grad_list)
+        DataParallelBuilder._check_grad_list(grad_list)
+        grads = allreduce_grads(grad_list)

         train_ops = []
         opt = get_opt_fn()
-        for idx in range(len(self.towers)):
+        for idx, grad_and_vars in enumerate(grads):
             with tf.device(raw_devices[idx]):
-                grad_and_vars = [x[idx] for x in grads]
                 # apply_gradients may create variables. Make them LOCAL_VARIABLES
                 with override_to_local_variable(enable=idx > 0):
                     train_ops.append(opt.apply_gradients(
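A pure-Python sketch of the shape change behind the rewritten loop above (names and values are illustrative, not from the commit): allreduce_grads builds a per-variable structure and transposes it before returning, so the training loop can enumerate towers directly instead of indexing every variable's list with [x[idx] for x in grads].

    # illustrative only: 2 variables (v0, v1) averaged across 2 towers (gpu0, gpu1)
    per_var = [                      # NVar x NGPU x 2, as built inside the reduce loop
        [('g0_gpu0', 'v0'), ('g0_gpu1', 'v0')],
        [('g1_gpu0', 'v1'), ('g1_gpu1', 'v1')],
    ]
    per_tower = list(zip(*per_var))  # NGPU x NVar x 2, what enumerate(grads) walks over
    assert per_tower[0] == (('g0_gpu0', 'v0'), ('g1_gpu0', 'v1'))
    assert per_tower[1] == (('g0_gpu1', 'v0'), ('g1_gpu1', 'v1'))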
@@ -9,7 +9,12 @@ import tensorflow as tf

 __all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
-           'override_to_local_variable']
+           'override_to_local_variable', 'allreduce_grads', 'average_grads']
+
+"""
+Some utilities for building the graph.
+"""

 @contextmanager
@@ -73,3 +78,66 @@ class LeastLoadedDeviceSetter(object):

     def __str__(self):
         return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
+
+
+def allreduce_grads(all_grads):
+    """
+    All-reduce average the gradients among devices. Results are broadcast to all devices.
+
+    Args:
+        all_grads (K x N x 2): A list of K lists. Each of the lists is a list of N (grad, var) tuples.
+            The variables have to be the same across the K lists.
+    Returns:
+        (K x N x 2): same as input, but each grad is replaced by the average over K lists.
+    """
+    from tensorflow.contrib import nccl
+    nr_tower = len(all_grads)
+    if nr_tower == 1:
+        return [[x] for x in all_grads[0]]
+    new_all_grads = []  # NVar * NGPU * 2
+    with tf.name_scope('AvgGrad'):
+        for grad_and_vars in zip(*all_grads):
+            v = grad_and_vars[0][1]
+            grads = [g for g, _ in grad_and_vars]
+            summed = nccl.all_sum(grads)
+            grads_for_a_var = []
+            for (_, v), g in zip(grad_and_vars, summed):
+                with tf.device(g.device):
+                    g = tf.multiply(g, 1.0 / nr_tower)
+                    grads_for_a_var.append((g, v))
+            new_all_grads.append(grads_for_a_var)
+
+    # transpose to NGPU * NVar * 2
+    ret = [k for k in zip(*new_all_grads)]
+    return ret
+
+
+def average_grads(all_grads):
+    """
+    Average the gradients, on the device of each variable.
+
+    Args:
+        all_grads (K x N x 2): A list of K lists. Each of the lists is a list of N (grad, var) tuples.
+            The variables have to be the same across the K lists.
+    Returns:
+        (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
+    """
+    nr_tower = len(all_grads)
+    if nr_tower == 1:
+        return all_grads[0]
+    ret = []
+    with tf.name_scope('AvgGrad'):
+        for grad_and_vars in zip(*all_grads):
+            # Ngpu * 2
+            v = grad_and_vars[0][1]
+            all_grads = [g for (g, _) in grad_and_vars]
+            with tf.device(v.device):       # colocate summed grad with var
+                grad = tf.multiply(
+                    tf.add_n(all_grads), 1.0 / nr_tower)
+                ret.append((grad, v))
+    return ret
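A minimal usage sketch of the new average_grads helper added above, assuming a TensorFlow 1.x graph-mode session; the import path is inferred from the relative imports in the diff and is an assumption, not stated in the commit.

    import tensorflow as tf
    from tensorpack.graph_builder.utils import average_grads  # assumed import path

    v = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
    # K=2 towers, N=1 variable: a K x N x 2 list of (grad, var) tuples
    tower_grads = [
        [(tf.constant([1.0, 1.0]), v)],   # gradient computed on tower 0
        [(tf.constant([3.0, 3.0]), v)],   # gradient computed on tower 1
    ]
    avg = average_grads(tower_grads)      # N x 2: one (averaged grad, var) pair

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(avg[0][0]))        # -> [2. 2.]

allreduce_grads follows the same input convention but keeps one reduced copy of the gradients per tower (K x N x 2); it relies on tensorflow.contrib.nccl and multiple GPU devices, so it is not sketched here.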