Commit 0d7e71df authored by Yuxin Wu

Add gradient packer for allreduce.

parent 5ad33556
@@ -16,8 +16,8 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, aggregate_grads, allreduce_hierarchical,
-    split_grad_list, merge_grad_list)
+    allreduce_grads, aggregate_grads, allreduce_grads_hierarchical,
+    split_grad_list, merge_grad_list, GradientPacker)

 __all__ = ['GraphBuilder',
@@ -61,7 +61,7 @@ class DataParallelBuilder(GraphBuilder):
                 inters &= s
             for s in names_per_gpu:
                 s -= inters
-            logger.error("Unique variables on towers: " + pprint.pformat(names_per_gpu))
+            logger.error("Unique trainable variables on towers: " + pprint.pformat(names_per_gpu))
             raise ValueError("Number of gradients from each tower is different! " + str(nvars))

     @staticmethod
@@ -185,7 +185,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
     def __init__(self, towers, average, mode):
         super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
         self._average = average
-        assert mode in ['nccl', 'cpu'], mode
+        assert mode in ['nccl', 'cpu', 'hierarchical'], mode
         self._mode = mode

     def build(self, get_grad_fn, get_opt_fn):
@@ -205,6 +205,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         """
         raw_devices = ['/gpu:{}'.format(k) for k in self.towers]

+        # #GPU x #VAR x 2
         grad_list = DataParallelBuilder.build_on_towers(
             self.towers,
             get_grad_fn,
@@ -213,12 +214,23 @@
         DataParallelBuilder._check_grad_list(grad_list)

-        if self._mode == 'nccl':
+        if self._mode == 'hierarchical' and len(raw_devices) < 8:
+            logger.warn("mode='hierarchical' requires >= 8 GPUs. Falling back to mode='cpu'.")
+            self._mode = 'cpu'
+
+        if self._mode in ['nccl', 'hierarchical']:
             all_grads, all_vars = split_grad_list(grad_list)
-            if True:
+            if self._mode == 'nccl':
                 all_grads = allreduce_grads(all_grads, average=self._average)  # #gpu x #param x 2
             else:
-                all_grads = allreduce_hierarchical(all_grads, raw_devices, average=self._average)
+                # all_grads = allreduce_grads_hierarchical(all_grads, raw_devices, average=self._average)
+                packer = GradientPacker(len(raw_devices))
+                packer.compute_strategy(all_grads[0])
+                packed_grads = packer.pack_all(all_grads, raw_devices)
+                packed_grads_aggr = allreduce_grads_hierarchical(packed_grads, raw_devices, average=self._average)
+                all_grads = packer.unpack_all(packed_grads_aggr, raw_devices)
+
             self.grads = merge_grad_list(all_grads, all_vars)
         elif self._mode == 'cpu':
             agg_grad_and_vars = aggregate_grads(
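For readers of the hunk above, a minimal usage sketch of the newly accepted mode, assuming the builder is driven directly (in practice it is wired up by the replicated trainer); the import path is assumed, and get_grad_fn/get_opt_fn stand in for the usual tower and optimizer callables:

    # Minimal sketch, not part of this commit. Import path and setup are assumed.
    from tensorpack.graph_builder.training import SyncMultiGPUReplicatedBuilder

    # 'hierarchical' is now accepted alongside 'nccl' and 'cpu'.
    builder = SyncMultiGPUReplicatedBuilder(
        towers=list(range(8)), average=True, mode='hierarchical')
    builder.build(get_grad_fn, get_opt_fn)   # get_grad_fn / get_opt_fn: placeholders for the usual callables
    # With fewer than 8 GPUs, build() logs a warning and falls back to mode='cpu'.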
@@ -8,7 +8,9 @@ import operator
 import tensorflow as tf

 from ..tfutils.varreplace import custom_getter_scope
-from ..tfutils.scope_utils import under_name_scope
+from ..tfutils.scope_utils import under_name_scope, cached_name_scope
+from ..utils.argtools import call_only_once
 from ..utils import logger

 __all__ = ['LeastLoadedDeviceSetter',
@@ -149,7 +151,7 @@ def allreduce_grads(all_grads, average):

 @under_name_scope('AllReduceGradsHierachical')
-def allreduce_hierarchical(all_grads, devices, average=False):
+def allreduce_grads_hierarchical(all_grads, devices, average=False):
     """
     Hierarchical allreduce for DGX-1 system.
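Only the signature of the renamed allreduce_grads_hierarchical is visible in this hunk. As a rough mental model, an illustration of the general two-level idea on an 8-GPU, DGX-1-like box (this is not the function's actual implementation; the group layout and device strings are assumptions):

    # Illustration only: a naive two-level allreduce over 8 devices (2 groups of 4).
    import tensorflow as tf

    def naive_hierarchical_allreduce(per_gpu_tensors, devices):
        # per_gpu_tensors[i] lives on devices[i]; returns the global sum, replicated on every device.
        assert len(per_gpu_tensors) == len(devices) == 8
        groups = [range(0, 4), range(4, 8)]
        # Stage 1: partial sum inside each group, placed on the group's first device.
        partials = []
        for g in groups:
            with tf.device(devices[g[0]]):
                partials.append(tf.add_n([per_gpu_tensors[i] for i in g]))
        # Stage 2: each group leader adds the other group's partial sum (one cross-group copy each way).
        totals = []
        for g in groups:
            with tf.device(devices[g[0]]):
                totals.append(tf.add_n(partials))
        # Stage 3: broadcast the total from each leader back to the rest of its group.
        out = [None] * 8
        for k, g in enumerate(groups):
            for i in g:
                with tf.device(devices[i]):
                    out[i] = tf.identity(totals[k])
        return out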
@@ -298,3 +300,77 @@ class OverrideCachingDevice(object):
         kwargs['caching_device'] = device_name
         var = getter(*args, **kwargs)
         return var
+
+
+class GradientPacker(object):
+    """
+    Concat gradients together to optimize transfer.
+    """
+
+    def __init__(self, num_split=8):
+        self._num_split = num_split
+
+    @call_only_once
+    def compute_strategy(self, grads):
+        for g in grads:
+            assert g.shape.is_fully_defined(), "Shape of {} is {}!".format(g.name, g.shape)
+        self._shapes = [g.shape for g in grads]
+        self._sizes = [g.shape.num_elements() for g in grads]
+        self._total_size = sum(self._sizes)
+        assert self._total_size > self._num_split
+
+        # should have the same dtype
+        dtypes = set([g.dtype for g in grads])
+        assert len(dtypes) == 1, dtypes
+
+        split_size = self._total_size // self._num_split
+        split_size_last = self._total_size - split_size * (self._num_split - 1)
+        self._split_sizes = [split_size] * (self._num_split - 1) + [split_size_last]
+        logger.info(
+            "Will pack {} gradients of total number={} into {} splits.".format(
+                len(self._sizes), self._total_size, self._num_split))
+
+    def pack(self, grads):
+        """
+        Args:
+            grads (list): list of gradient tensors
+
+        Returns:
+            packed list of gradient tensors to be aggregated.
+        """
+        for i, g in enumerate(grads):
+            assert g.shape == self._shapes[i]
+        with cached_name_scope("GradientPacker", top_level=False):
+            concat_grads = tf.concat([tf.reshape(g, [-1]) for g in grads], 0, name='concatenated_grads')
+            grad_packs = tf.split(concat_grads, self._split_sizes)
+        return grad_packs
+
+    def unpack(self, grad_packs):
+        with cached_name_scope("GradientPacker", top_level=False):
+            concat_grads = tf.concat(grad_packs, 0, name='concatenated_packs')
+            flattened_grads = tf.split(concat_grads, self._sizes)
+            grads = [tf.reshape(g, shape) for g, shape in zip(flattened_grads, self._shapes)]
+        return grads
+
+    def pack_all(self, all_grads, devices):
+        """
+        Args:
+            all_grads: K x N, K lists of gradients to be packed
+        """
+        ret = []    # #GPU x #split
+        for dev, grads in zip(devices, all_grads):
+            with tf.device(dev):
+                ret.append(self.pack(grads))
+        return ret
+
+    def unpack_all(self, all_packed, devices):
+        """
+        Args:
+            all_packed: K lists of packed gradients.
+        """
+        all_grads = []  # #GPU x #Var
+        for dev, packed_grads_single_device in zip(devices, all_packed):
+            with tf.device(dev):
+                all_grads.append(self.unpack(packed_grads_single_device))
+        return all_grads
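To make the packing arithmetic above concrete, a small self-contained sketch of the pack / aggregate / unpack round trip (assuming TF1-style graph construction and that GradientPacker is importable from tensorpack.graph_builder.utils; the dummy shapes and CPU devices are arbitrary):

    # Sketch only: exercising GradientPacker on dummy per-GPU gradients.
    import numpy as np
    import tensorflow as tf
    from tensorpack.graph_builder.utils import GradientPacker   # import path assumed

    devices = ['/cpu:0', '/cpu:0']     # stand-ins for '/gpu:0', '/gpu:1'
    # Two "GPUs", each holding the same three gradient tensors (6 + 4 + 12 = 22 elements).
    all_grads = [[tf.constant(np.ones(n, 'float32').reshape(shape))
                  for n, shape in [(6, (2, 3)), (4, (4,)), (12, (3, 4))]]
                 for _ in devices]

    packer = GradientPacker(num_split=2)
    packer.compute_strategy(all_grads[0])          # split sizes: 22 // 2 = 11 and 22 - 11 = 11
    packed = packer.pack_all(all_grads, devices)   # 2 devices x 2 flat chunks of 11 elements
    # ... the packed chunks, not the raw gradients, go through allreduce_grads_hierarchical ...
    restored = packer.unpack_all(packed, devices)  # back to 2 devices x 3 tensors with original shapes

Packing matters here because the hierarchical allreduce then runs over a handful of large, equally sized chunks per GPU instead of one transfer per (possibly tiny) gradient tensor.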