Commit 6151e048 authored by Yuxin Wu

rewrite allreduce and avoid bug in TF's nccl

parent dbc0b36e
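
For readers landing on this commit page, here is a minimal usage sketch (not part of the commit) of the trainer whose allreduce behavior this change rewrites; the `mode` values are the ones named in the diff below, and the commented launch line assumes a `traincfg` built elsewhere:

    # Hypothetical sketch, not from this commit: selecting the allreduce mode
    # of SyncMultiGPUTrainerReplicated after this change.
    from tensorpack import SyncMultiGPUTrainerReplicated

    # mode=None lets the trainer pick a default: 'hierarchical' on 8 GPUs,
    # otherwise 'nccl' on TF<1.15 and 'gpu' on newer TF (see the trainers.py hunk below).
    trainer = SyncMultiGPUTrainerReplicated(8, average=False, mode=None)
    # launch_train_with_config(traincfg, trainer)
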
@@ -116,5 +116,5 @@ if __name__ == '__main__':
         trainer = HorovodTrainer(average=False)
     else:
         # nccl mode appears faster than cpu mode
-        trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, average=False, mode='nccl')
+        trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, average=False)
     launch_train_with_config(traincfg, trainer)
@@ -8,7 +8,7 @@ from ..tfutils.common import get_global_step_var, get_op_tensor_name
 from ..utils import logger
 from ..utils.argtools import memoized
 from .training import DataParallelBuilder, GraphBuilder
-from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
+from .utils import OverrideCachingDevice, split_grad_list, allreduce_grads_naive, override_to_local_variable

 __all__ = []
@@ -123,7 +123,9 @@ class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderB
         DataParallelBuilder._check_grad_list(grad_list)

         with tf.device(self.param_server_device):
-            grads = aggregate_grads(grad_list, colocation=False)
+            all_grads, all_vars = split_grad_list(grad_list)
+            all_grads = allreduce_grads_naive(all_grads)
+            grads = [(g, v) for g, v in zip(all_grads, all_vars[0])]
             opt = get_opt_fn()
             train_op = opt.apply_gradients(grads, name='train_op')
             train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])
@@ -285,8 +287,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
                 use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)

-        avg_grads = aggregate_grads(
-            grad_list, colocation=False, devices=self.raw_devices)
+        all_grads, all_vars = split_grad_list(grad_list)
+        avg_grads = allreduce_grads_naive(all_grads, devices=self.raw_devices)  # N
+        avg_grads = [(g, v) for g, v in zip(avg_grads, all_vars[0])]

         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(
......
@@ -16,7 +16,9 @@ from ..tfutils.tower import TrainTowerContext
 from ..utils import logger
 from ..utils.develop import HIDE_DOC
 from .utils import (
-    GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
+    GradientPacker, LeastLoadedDeviceSetter,
+    aggregate_grads_colocate, allreduce_grads_naive,
+    allreduce_grads, allreduce_grads_hierarchical,
     merge_grad_list, override_to_local_variable, split_grad_list)

 __all__ = ["DataParallelBuilder"]
@@ -173,12 +175,13 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         assert len(grad_list) == len(self.towers)
         DataParallelBuilder._check_grad_list(grad_list)

-        # debug tower performance (without update):
+        # debug tower performance:
         # ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
         # self.train_op = tf.group(*ops)
         # return

-        self.grads = aggregate_grads(grad_list, colocation=True)
+        self.grads = aggregate_grads_colocate(grad_list)
+        # debug tower performance:
         # grads = grad_list[0]

         opt = get_opt_fn()
@@ -204,13 +207,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
     def __init__(self, towers, average, mode):
         super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
         self._average = average
-        assert mode in ['nccl', 'cpu', 'hierarchical'], mode
-        if get_tf_version_tuple() >= (2, 0) and mode == 'cpu':
-            mode = 'nccl'  # cpu mode causes the entire model to get located on cpu
+        assert mode in ['nccl', 'cpu', 'hierarchical', 'gpu', 'collective'], mode
         self._mode = mode

         if self._mode == 'hierarchical' and len(towers) != 8:
-            logger.warn("mode='hierarchical' require >= 8 GPUs. Fallback to mode='nccl'.")
+            logger.warn("mode='hierarchical' requires 8 GPUs. Fallback to mode='nccl'.")
             self._mode = 'nccl'

     def call_for_each_tower(self, tower_fn):
@@ -257,39 +258,38 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         valid_for_nccl = all(k in dtypes_nccl_supported for k in dtypes)
         if self._mode == 'nccl' and not valid_for_nccl:
             logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
-            self._mode = 'cpu'
+            self._mode = 'gpu'

-        if self._mode in ['nccl', 'hierarchical']:
-            all_grads, all_vars = split_grad_list(grad_list)
+        all_grads, all_vars = split_grad_list(grad_list)
+
+        def do_allreduce(all_grads):
             # use allreduce from tf-benchmarks
             # from .batch_allreduce import AllReduceSpecAlgorithm
             # algo = AllReduceSpecAlgorithm('nccl', list(range(8)), 0, 10)
             # all_grads, warmup_ops = algo.batch_all_reduce(all_grads, 1, True, False)
             # print("WARMUP OPS", warmup_ops)
-            if self._mode == 'nccl':
-                all_grads = allreduce_grads(all_grads, average=self._average)  # #gpu x #param
+            if self._mode in ['nccl', 'collective']:
+                # #gpu x #param
+                all_grads = allreduce_grads(all_grads, average=self._average, mode=self._mode)
+            elif self._mode == 'hierarchical':
+                all_grads = allreduce_grads_hierarchical(all_grads, raw_devices, average=self._average)
             else:
-                packer = GradientPacker(len(raw_devices))
-                succ = packer.compute_strategy(all_grads[0])
-                if succ:
-                    packed_grads = packer.pack_all(all_grads, raw_devices)
-                    packed_grads_aggr = allreduce_grads_hierarchical(
-                        packed_grads, raw_devices, average=self._average)
-                    all_grads = packer.unpack_all(packed_grads_aggr, raw_devices)
-                else:
-                    all_grads = allreduce_grads_hierarchical(all_grads, raw_devices, average=self._average)
-            self.grads = merge_grad_list(all_grads, all_vars)
-        elif self._mode == 'cpu':
-            agg_grad_and_vars = aggregate_grads(
-                grad_list, colocation=False,
-                devices=['/cpu:0'], average=self._average)  # #param x 2
-            self.grads = []  # #gpu x #param x 2
-            for grad_and_vars in grad_list:  # grad_and_vars: #paramx2
-                # take v from each tower, and g from average.
-                self.grads.append(
-                    [(g, v) for (_, v), (g, _) in zip(grad_and_vars, agg_grad_and_vars)])
+                devices = ['/cpu:0'] if self._mode == 'cpu' else raw_devices
+                all_grads = allreduce_grads_naive(all_grads, devices=devices, average=self._average)
+                all_grads = [all_grads] * len(self.towers)
+            return all_grads
+
+        use_packer = self._mode in ['hierarchical']
+        if use_packer:
+            packer = GradientPacker(len(raw_devices))
+            use_packer = packer.compute_strategy(all_grads[0])  # may fail to pack
+        if use_packer:
+            all_grads = packer.pack_all(all_grads, raw_devices)
+        all_grads = do_allreduce(all_grads)  # all the work happens here
+        if use_packer:
+            all_grads = packer.unpack_all(all_grads, raw_devices)
+
+        self.grads = merge_grad_list(all_grads, all_vars)

         train_ops = []
         opt = get_opt_fn()
......
@@ -5,6 +5,7 @@
 import operator
 from contextlib import contextmanager
 import tensorflow as tf
+import threading

 from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
@@ -13,7 +14,7 @@ from ..tfutils.varreplace import custom_getter_scope
 from ..utils import logger
 from ..utils.argtools import call_only_once

-__all__ = ["LeastLoadedDeviceSetter", "allreduce_grads", "aggregate_grads"]
+__all__ = ["LeastLoadedDeviceSetter", "allreduce_grads"]

 """
@@ -33,6 +34,19 @@ def _replace_global_by_local(kwargs):
     kwargs['collections'] = list(collections)


+_module_lock = threading.Lock()
+_shared_cnt_counter = 0
+
+
+def _get_shared_cnt():
+    global _shared_cnt_counter
+    with _module_lock:
+        val = _shared_cnt_counter
+        _shared_cnt_counter += 1
+        return val
+
+
 @contextmanager
 def override_to_local_variable(enable=True):
     """
@@ -84,17 +98,18 @@ class LeastLoadedDeviceSetter(object):
         if op.type not in ['Variable', 'VariableV2']:
             return canonicalize(self.worker_device)

-        device_index, _ = min(enumerate(
-            self.ps_sizes), key=operator.itemgetter(1))
+        device_name = self.place_with_balance(op)
+        return canonicalize(device_name)
+
+    def place_with_balance(self, op):
+        device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
         device_name = self.ps_devices[device_index]
         var_size = op.outputs[0].get_shape().num_elements()
         if var_size is None:
             logger.warn("[LeastLoadedDeviceSetter] Shape of variable {} is not fully defined!".format(op.name))
             var_size = 0
         self.ps_sizes[device_index] += var_size
-        return canonicalize(device_name)
+        return device_name

     def __str__(self):
         return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
@@ -130,28 +145,42 @@ def merge_grad_list(all_grads, all_vars):
 @under_name_scope('AllReduceGrads')
-def allreduce_grads(all_grads, average):
+def allreduce_grads(all_grads, average, mode="nccl"):
     """
     All-reduce average the gradients among K devices. Results are broadcasted to all devices.

     Args:
         all_grads (K x N): List of list of gradients. N is the number of variables.
         average (bool): average gradients or not.
+        mode (str): "nccl", "collective"

     Returns:
         K x N: same as input, but each grad is replaced by the average over K devices.
     """
-    if get_tf_version_tuple() <= (1, 12):
-        from tensorflow.contrib import nccl  # deprecated
-    else:
-        from tensorflow.python.ops import nccl_ops as nccl
+    assert mode in ["nccl", "collective"], mode
     nr_tower = len(all_grads)
     if nr_tower == 1:
         return all_grads
     new_all_grads = []  # N x K
     for grads in zip(*all_grads):
-        summed = nccl.all_sum(grads)
+        # k grads
+        if mode == "nccl":
+            if get_tf_version_tuple() <= (1, 12):
+                from tensorflow.contrib import nccl  # deprecated
+            else:
+                from tensorflow.python.ops import nccl_ops as nccl
+            summed = nccl.all_sum(grads)
+        else:
+            from tensorflow.python.ops import collective_ops
+            summed = []
+            shared_cnt = _get_shared_cnt()
+            for t in grads:
+                with tf.device(t.device):
+                    t = collective_ops.all_reduce(
+                        t, len(grads), shared_cnt, shared_cnt + 100,
+                        'Add', 'Id')
+                    summed.append(t)

         grads_for_devices = []  # K
         for g in summed:
@@ -229,29 +258,18 @@ def allreduce_grads_hierarchical(all_grads, devices, average=False):
     return agg_all_grads


-@under_name_scope('AggregateGrads')
-def aggregate_grads(all_grads,
-                    colocation=False,
-                    devices=None,
-                    average=True):
+@under_name_scope('AggregateGradsColocate')
+def aggregate_grads_colocate(all_grads, average=True):
     """
-    Average the gradients.
+    Aggregate the gradients. The aggregation is colocated with the variable.

     Args:
         all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
-            The variables have to be the same across the K lists.
-        colocation (bool): colocate gradient averaging on the device of the variable.
-        devices (list[str]): assign the averaging to these device in
-            round-robin. Cannot be used together with ``colocation``.
+            The variables have to be shared across the K lists.
         average (bool): do average or sum

     Returns:
         (N x 2): A list of N (grad, var) tuples, where grad is averaged or summed over K.
     """
-    assert not (devices is not None and colocation)
-    if devices is not None:
-        assert isinstance(devices, list), devices
     nr_tower = len(all_grads)
     if nr_tower == 1:
         return all_grads[0]
@@ -267,21 +285,57 @@ def aggregate_grads(all_grads,
         # Ngpu * 2
         v = grad_and_vars[0][1]
         grads = [g for (g, _) in grad_and_vars]
-        if colocation:
-            with tf.device(v.device):       # colocate summed grad with var
-                grad = aggregate(grads)
-        elif devices is None:
-            grad = aggregate(grads)
-        else:
-            dev = devices[idx % len(devices)]
-            with tf.device(dev):
-                grad = aggregate(grads)
+        with tf.device(v.device):       # colocate summed grad with var
+            grad = aggregate(grads)
         ret.append((grad, v))
     return ret


-average_grads = aggregate_grads
+@under_name_scope('AllReduceNaive')
+def allreduce_grads_naive(all_grads, devices=None, average=True):
+    """
+    AllReduce the gradients with raw ops (instead of collective ops).
+
+    Args:
+        all_grads (K x N): A list of K lists. Each of the list is a list of N grad tuples.
+            The variables have to be the same across the K lists.
+        devices (list[str]): assign the averaging to these device in
+            round-robin. Cannot be used together with ``colocation``.
+        average (bool): do average or sum
+
+    Returns:
+        list[Tensor]: list of grads where each grad is averaged or summed over K.
+    """
+    if devices is not None:
+        assert isinstance(devices, list), devices
+        # device_setter = LeastLoadedDeviceSetter(None, devices)
+    nr_tower = len(all_grads)
+    if nr_tower == 1:
+        return all_grads[0]
+
+    def aggregate(grads):
+        if average:
+            return tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
+        else:
+            return tf.add_n(grads)
+
+    grads_ret = []  # N(rev) grads
+    # reverse so the device placement makes the last part of model more balance?
+    all_grads_rev = [x[::-1] for x in all_grads]  # K x N(rev)
+    for idx, grads in enumerate(zip(*all_grads_rev)):
+        # grads: K tensors
+        if devices is None:
+            grad = aggregate(grads)
+        else:
+            # dev = device_setter.place_with_balance(v.op)
+            dev = devices[idx % len(devices)]
+            with tf.device(dev):
+                grad = aggregate(grads)
+        grads_ret.append(grad)
+    grads_ret = grads_ret[::-1]
+    return grads_ret

 # https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
@@ -319,6 +373,8 @@ class OverrideCachingDevice(object):
         return var


+# TODO pack at variable boundary, so that the concat does not have to wait for all
+# grads to be ready
 class GradientPacker(object):
     """
     Concat gradients together to optimize transfer.
......
@@ -290,6 +290,9 @@ class Trainer(object):
             except KeyboardInterrupt:
                 logger.info("Detected Ctrl-C and exiting main loop.")
                 raise
+            except Exception:
+                logger.error("Training failed at global_step={}".format(self.loop.global_step))
+                raise
             finally:
                 self._callbacks.after_train()
                 self.hooked_sess.close()
......
@@ -117,7 +117,7 @@ class ModelDesc(ModelDescBase):
         """
         ret = self.optimizer()
         assert isinstance(ret, tfv1.train.Optimizer), \
-            "ModelDesc.optimizer() must return a tf.train.Optimizer! Got {} instead.".format(str(ret))
+            "ModelDesc.optimizer() must return an instance of tf.train.Optimizer! Got {} instead.".format(str(ret))
         return ret

     def optimizer(self):
......
@@ -13,6 +13,7 @@ from ..graph_builder.training import (
 from ..graph_builder.utils import override_to_local_variable
 from ..input_source import FeedfreeInput, QueueInput
 from ..tfutils import get_global_step_var
+from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.distributed import get_distributed_session_creator
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.tower import TrainTowerContext
@@ -173,10 +174,26 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
         "hierarchical" mode was designed for DGX-like 8GPU machines.
         """
         self.devices = gpus

+        if mode is not None:
+            mode = mode.lower()
+
+        # Heuristics about mode selection:
+        if mode == 'hierarchical' and len(gpus) != 8:
+            logger.warn("mode='hierarchical' requires 8 GPUs. Will fallback to default mode.")
+            mode = None
+
         if mode is None:
-            mode = 'hierarchical' if len(gpus) == 8 else 'nccl'
-        mode = mode.lower()
+            if len(gpus) == 8:
+                mode = 'hierarchical'
+            else:
+                # https://github.com/tensorflow/tensorflow/issues/41539
+                mode = 'nccl' if get_tf_version_tuple() < (1, 15) else 'gpu'
+        if mode == 'cpu' and get_tf_version_tuple() >= (2, 0):
+            # cpu mode causes the entire model to get located on cpu
+            mode = 'gpu'
+        if mode == 'nccl' and get_tf_version_tuple() >= (1, 15):
+            logger.warning(
+                "NCCL in TensorFlow has a serious bug that is likely to trigger in TF>=1.15. "
+                "Try 'mode=None' to use a better default mode.")

         self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
         self.BROADCAST_EVERY_EPOCH = True
......
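
To make the new helper concrete, a dummy-value sketch (not part of the commit; it assumes `allreduce_grads_naive` keeps the signature shown in the utils.py hunk above):

    # Hypothetical sketch: averaging two towers' gradients with the new helper,
    # using constants in place of real gradient tensors.
    import tensorflow as tf
    from tensorpack.graph_builder.utils import allreduce_grads_naive

    all_grads = [
        [tf.constant(1.0), tf.constant(2.0)],   # tower 0
        [tf.constant(3.0), tf.constant(4.0)],   # tower 1
    ]
    averaged = allreduce_grads_naive(all_grads, average=True)   # two tensors: 2.0 and 3.0
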