Commit d4a432ad authored by Yuxin Wu

Refactor gradient aggregation in replicated trainer

parent e44e9c04
@@ -376,6 +376,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         'PrefetchOnGPUs',
         'PeriodicRunHooks',
         'apply_default_prefetch',
+        'average_grads',
+        'Deconv2D',
         'saliency_map', 'get_scalar_var', 'psnr',
         'prediction_incorrect', 'huber_loss', 'SoftMax'
@@ -12,7 +12,7 @@ from ..tfutils.common import get_op_tensor_name, get_global_step_var
 from .training import GraphBuilder, DataParallelBuilder
 from .utils import (
-    override_to_local_variable, average_grads,
+    override_to_local_variable, aggregate_grads,
     OverrideCachingDevice)
 __all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']

@@ -126,7 +126,7 @@ class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderBase):
         DataParallelBuilder._check_grad_list(grad_list)
         with tf.device(self.param_server_device):
-            grads = average_grads(grad_list, colocation=False)
+            grads = aggregate_grads(grad_list, colocation=False)
             opt = get_opt_fn()
             train_op = opt.apply_gradients(grads, name='train_op')
         train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])

@@ -287,7 +287,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
                 use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)
-        avg_grads = average_grads(
+        avg_grads = aggregate_grads(
             grad_list, colocation=False, devices=self.raw_devices)
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
@@ -16,7 +16,7 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, average_grads)
+    allreduce_grads, aggregate_grads)
 __all__ = ['GraphBuilder',

@@ -154,7 +154,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         # self.train_op = tf.group(*ops)
         # return
-        self.grads = average_grads(grad_list, colocation=True)
+        self.grads = aggregate_grads(grad_list, colocation=True)
         # grads = grad_list[0]
         opt = get_opt_fn()
@@ -181,10 +181,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
     Though on different devices, they should contain the same value.
     """
-    def __init__(self, towers, average, use_nccl):
+    def __init__(self, towers, average, mode):
         super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
         self._average = average
-        self._use_nccl = use_nccl
+        assert mode in ['nccl', 'cpu'], mode
+        self._mode = mode

     def build(self, get_grad_fn, get_opt_fn):
         """

@@ -211,10 +212,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         DataParallelBuilder._check_grad_list(grad_list)
-        if self._use_nccl:
+        if self._mode == 'nccl':
             self.grads = allreduce_grads(grad_list, average=self._average)  # #gpu x #param x 2
-        else:
+        elif self._mode == 'cpu':
-            agg_grad_and_vars = average_grads(
+            agg_grad_and_vars = aggregate_grads(
                 grad_list, colocation=False,
                 devices=['/cpu:0'], average=self._average)  # #param x 2
             self.grads = []  # #gpu x #param x 2
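For context, the two branches above differ only in where the reduction happens: in 'nccl' mode each GPU receives an all-reduced copy of the gradients, while in 'cpu' mode every tower's gradients are copied to the CPU, summed or averaged there once, and the single result is reused by all towers. A minimal sketch of the CPU-side reduction, using toy constants rather than real tower gradients (TF 1.x graph mode assumed):

```python
import tensorflow as tf

# Toy per-tower gradients for a single variable (assumption: two towers).
tower_grads = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]

with tf.device('/cpu:0'):
    agg = tf.add_n(tower_grads)            # sum over towers
    agg = agg * (1.0 / len(tower_grads))   # divide to average instead of sum

with tf.Session() as sess:
    print(sess.run(agg))                   # [2. 3.]
```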
@@ -14,7 +14,8 @@ __all__ = ['LeastLoadedDeviceSetter',
            'OverrideCachingDevice',
            'override_to_local_variable',
            'allreduce_grads',
-           'average_grads']
+           'average_grads',
+           'aggregate_grads']
 """

@@ -119,7 +120,10 @@ def allreduce_grads(all_grads, average):
     return ret

-def average_grads(all_grads, colocation=True, devices=None, average=True):
+def aggregate_grads(all_grads,
+                    colocation=False,
+                    devices=None,
+                    average=True):
     """
     Average the gradients.

@@ -132,7 +136,7 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
         average (bool): do average or sum
     Returns:
-        (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
+        (N x 2): A list of N (grad, var) tuples, where grad is averaged or summed over K.
     """
     assert not (devices is not None and colocation)
     if devices is not None:

@@ -149,7 +153,7 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
             return tf.add_n(grads)
     ret = []
-    with tf.name_scope('AvgGrad'):
+    with tf.name_scope('AggregateGrad'):
         for idx, grad_and_vars in enumerate(zip(*all_grads)):
             # Ngpu * 2
             v = grad_and_vars[0][1]

@@ -168,6 +172,9 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
     return ret

+average_grads = aggregate_grads

 # https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
 class OverrideCachingDevice(object):
     """Variable getter which caches variables on the least loaded device.
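The old name is kept as a module-level alias (average_grads = aggregate_grads), so existing code that imports the previous name keeps working. A toy usage sketch of the renamed function, assuming it can be imported from tensorpack.graph_builder.utils as the diff suggests:

```python
import tensorflow as tf
from tensorpack.graph_builder.utils import aggregate_grads

# Build a "#gpu x #param x 2" grad_list: 2 towers, 1 variable.
w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
grad_list = [
    [(tf.constant([1.0, 2.0]), w)],   # gradients from tower 0
    [(tf.constant([3.0, 4.0]), w)],   # gradients from tower 1
]

# Sum (average=False) the per-tower gradients on the CPU.
grads_and_vars = aggregate_grads(
    grad_list, colocation=False, devices=['/cpu:0'], average=False)
```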
@@ -140,15 +140,19 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     """
     @map_arg(gpus=_int_to_range)
-    def __init__(self, gpus, average=True, use_nccl=True):
+    def __init__(self, gpus, average=True, mode='nccl', use_nccl=None):
         """
         Args:
             gpus (int or [int]): list of GPU ids.
             average (bool): whether to average or sum gradients.
-            use_nccl (bool): use NCCL or TensorFlow copy to reduce.
+            mode (str): Gradient aggregation mode. Supported values: ['nccl', 'cpu']
         """
         self.devices = gpus
-        self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, use_nccl)
+        if use_nccl is not None:
+            mode = 'nccl' if use_nccl else 'cpu'
+            logger.warn("use_nccl option was deprecated! Use the `mode` option instead!")
+        mode = mode.lower()
+        self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
         super(SyncMultiGPUTrainerReplicated, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
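With this change the trainer selects its aggregation backend by a string; the old boolean is still accepted, but only mapped onto the new option with a deprecation warning. A hypothetical construction, assuming the trainer is exported from tensorpack.train as in current releases:

```python
from tensorpack.train import SyncMultiGPUTrainerReplicated

# New style: choose the aggregation backend explicitly.
trainer = SyncMultiGPUTrainerReplicated([0, 1], average=True, mode='nccl')

# Old style: still accepted, but logs a deprecation warning and is
# internally rewritten to mode='cpu'.
legacy = SyncMultiGPUTrainerReplicated([0, 1], average=True, use_nccl=False)
```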