Commit d4a432ad authored by Yuxin Wu

Refactor gradient aggregation in replicated trainer

parent e44e9c04
......@@ -376,6 +376,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'PrefetchOnGPUs',
'PeriodicRunHooks',
'apply_default_prefetch',
'average_grads',
'Deconv2D',
'saliency_map', 'get_scalar_var', 'psnr',
'prediction_incorrect', 'huber_loss', 'SoftMax'
......
......@@ -12,7 +12,7 @@ from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .training import GraphBuilder, DataParallelBuilder
from .utils import (
override_to_local_variable, average_grads,
override_to_local_variable, aggregate_grads,
OverrideCachingDevice)
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
......@@ -126,7 +126,7 @@ class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderB
DataParallelBuilder._check_grad_list(grad_list)
with tf.device(self.param_server_device):
grads = average_grads(grad_list, colocation=False)
grads = aggregate_grads(grad_list, colocation=False)
opt = get_opt_fn()
train_op = opt.apply_gradients(grads, name='train_op')
train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])
......@@ -287,7 +287,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
use_vs=[True] * len(self.towers)) # open vs at each tower
DataParallelBuilder._check_grad_list(grad_list)
avg_grads = average_grads(
avg_grads = aggregate_grads(
grad_list, colocation=False, devices=self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
......
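In both distributed builders the renamed helper runs with colocation=False, so the reduction is placed on the parameter-server device instead of next to each variable, and the reduced (grad, var) list is fed straight into apply_gradients. Below is a condensed sketch of that pattern, not the builder code itself; the tensorpack.graph_builder.utils import path is assumed, and the function name and ps_device default are illustrative only.

import tensorflow as tf
from tensorpack.graph_builder.utils import aggregate_grads

def build_ps_train_op(grad_list, opt, ps_device='/job:ps/task:0/cpu:0'):
    # grad_list: one [(grad, var), ...] list per tower.
    # colocation=False keeps the add_n reductions off the variables' own
    # devices; the enclosing scope pins them to the parameter-server device.
    with tf.device(ps_device):
        grads = aggregate_grads(grad_list, colocation=False)
        return opt.apply_gradients(grads, name='train_op')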
......@@ -16,7 +16,7 @@ from ..tfutils.gradproc import ScaleGradient
from .utils import (
LeastLoadedDeviceSetter, override_to_local_variable,
allreduce_grads, average_grads)
allreduce_grads, aggregate_grads)
__all__ = ['GraphBuilder',
......@@ -154,7 +154,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
# self.train_op = tf.group(*ops)
# return
self.grads = average_grads(grad_list, colocation=True)
self.grads = aggregate_grads(grad_list, colocation=True)
# grads = grad_list[0]
opt = get_opt_fn()
......@@ -181,10 +181,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
Though on different devices, they should contain the same value.
"""
def __init__(self, towers, average, use_nccl):
def __init__(self, towers, average, mode):
super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
self._average = average
self._use_nccl = use_nccl
assert mode in ['nccl', 'cpu'], mode
self._mode = mode
def build(self, get_grad_fn, get_opt_fn):
"""
......@@ -211,10 +212,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
DataParallelBuilder._check_grad_list(grad_list)
if self._use_nccl:
if self._mode == 'nccl':
self.grads = allreduce_grads(grad_list, average=self._average) # #gpu x #param x 2
else:
agg_grad_and_vars = average_grads(
elif self._mode == 'cpu':
agg_grad_and_vars = aggregate_grads(
grad_list, colocation=False,
devices=['/cpu:0'], average=self._average) # #param x 2
self.grads = [] # #gpu x #param x 2
......
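The boolean use_nccl flag on the replicated builder is replaced by a mode string: 'nccl' all-reduces the gradients across GPUs, while 'cpu' aggregates them once on the CPU and hands the result back to every tower. A minimal sketch of that dispatch (my own wrapper, not the builder code), assuming the helpers exported from tensorpack.graph_builder.utils after this change:

from tensorpack.graph_builder.utils import allreduce_grads, aggregate_grads

def reduce_tower_grads(grad_list, mode, average=True):
    # grad_list: one [(grad, var), ...] list per GPU tower.
    if mode == 'nccl':
        # NCCL all-reduce: every tower keeps its own reduced copy,
        # so the shape stays  #gpu x #param x 2.
        return allreduce_grads(grad_list, average=average)
    elif mode == 'cpu':
        # Reduce once on the CPU (#param x 2), then give the same reduced
        # gradients back to each tower.  The real builder additionally
        # re-pairs every gradient with the tower-local copy of its variable.
        agg = aggregate_grads(grad_list, colocation=False,
                              devices=['/cpu:0'], average=average)
        return [list(agg) for _ in grad_list]
    raise ValueError("mode must be 'nccl' or 'cpu', got {}".format(mode))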
......@@ -14,7 +14,8 @@ __all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice',
'override_to_local_variable',
'allreduce_grads',
'average_grads']
'average_grads',
'aggregate_grads']
"""
......@@ -119,7 +120,10 @@ def allreduce_grads(all_grads, average):
return ret
def average_grads(all_grads, colocation=True, devices=None, average=True):
def aggregate_grads(all_grads,
colocation=False,
devices=None,
average=True):
"""
Average or sum the gradients.
......@@ -132,7 +136,7 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
average (bool): do average or sum
Returns:
(N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
(N x 2): A list of N (grad, var) tuples, where grad is averaged or summed over K.
"""
assert not (devices is not None and colocation)
if devices is not None:
......@@ -149,7 +153,7 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
return tf.add_n(grads)
ret = []
with tf.name_scope('AvgGrad'):
with tf.name_scope('AggregateGrad'):
for idx, grad_and_vars in enumerate(zip(*all_grads)):
# Ngpu * 2
v = grad_and_vars[0][1]
......@@ -168,6 +172,9 @@ def average_grads(all_grads, colocation=True, devices=None, average=True):
return ret
average_grads = aggregate_grads
# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
class OverrideCachingDevice(object):
"""Variable getter which caches variables on the least loaded device.
......
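Renamed from average_grads, aggregate_grads now defaults to colocation=False and can either average or sum the per-tower gradients; the old name is kept as an alias so existing callers keep working. A small usage sketch in TF 1.x graph mode (the gradient values are made up, and the tensorpack.graph_builder.utils import path is assumed):

import tensorflow as tf
from tensorpack.graph_builder.utils import aggregate_grads

w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
# Pretend these (grad, var) pairs came from two towers.
tower0 = [(tf.constant([1., 1., 1.]), w)]
tower1 = [(tf.constant([3., 3., 3.]), w)]

# Average on the CPU; pass average=False to sum instead.
reduced = aggregate_grads([tower0, tower1], colocation=False,
                          devices=['/cpu:0'], average=True)

with tf.Session() as sess:
    print(sess.run(reduced[0][0]))    # -> [2. 2. 2.]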
......@@ -140,15 +140,19 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, average=True, use_nccl=True):
def __init__(self, gpus, average=True, mode='nccl', use_nccl=None):
"""
Args:
gpus (int or [int]): list of GPU ids.
average (bool): whether to average or sum gradients.
use_nccl (bool): use NCCL or TensorFlow copy to reduce.
mode (str): Gradient aggregation mode. Supported values: ['nccl', 'cpu']
"""
self.devices = gpus
self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, use_nccl)
if use_nccl is not None:
mode = 'nccl' if use_nccl else 'cpu'
logger.warn("use_nccl option was deprecated! Use the `mode` option instead!")
mode = mode.lower()
self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
super(SyncMultiGPUTrainerReplicated, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
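On the user-facing trainer the same switch applies: mode='nccl' or mode='cpu' replaces the boolean use_nccl, which now only emits a deprecation warning and is mapped onto the new option. A quick sketch of constructing the trainer after this change (assuming the class is importable from the top-level tensorpack package):

from tensorpack import SyncMultiGPUTrainerReplicated

trainer = SyncMultiGPUTrainerReplicated(4, average=True, mode='nccl')
# The deprecated spelling still works for now and is translated internally
# (use_nccl=False becomes mode='cpu'):
legacy_trainer = SyncMultiGPUTrainerReplicated(4, use_nccl=False)

The resulting trainer is then handed to launch_train_with_config as usual; that part is unchanged by this commit.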