Commit 05a424b1 authored by Yuxin Wu

Try average by copy in ReplicatedTrainer

parent e086f05a
@@ -221,13 +221,14 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', required=True)
     parser.add_argument('--batch', type=int, default=32)
+    parser.add_argument('--aug', choices=['train', 'val'], default='val')
     args = parser.parse_args()
 
-    augs = fbresnet_augmentor(False)
-    augs = [imgaug.ResizeShortestEdge(256),
-            imgaug.CenterCrop(224)
-            ]
+    if args.aug == 'val':
+        augs = fbresnet_augmentor(False)
+    elif args.aug == 'train':
+        augs = fbresnet_augmentor(True)
     df = get_imagenet_dataflow(
         args.data, 'train', args.batch, augs)
     # For val augmentor, Should get >100 it/s (i.e. 3k im/s) here on a decent E5 server.
     TestDataSpeed(df).start()
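
The new flag lets the data-speed benchmark exercise either augmentation pipeline. Below is a minimal sketch, not part of the commit, of running both pipelines back to back; it assumes fbresnet_augmentor and get_imagenet_dataflow are available from the module patched above (its filename is not shown in this diff), while TestDataSpeed comes from tensorpack.dataflow.

    # Sketch only: benchmark the val and train augmentors back to back.
    # fbresnet_augmentor / get_imagenet_dataflow are assumed to be in scope
    # (defined in the module patched above); TestDataSpeed is tensorpack's.
    from tensorpack.dataflow import TestDataSpeed

    def benchmark_both(data_dir, batch=32):
        for name, is_train in [('val', False), ('train', True)]:
            augs = fbresnet_augmentor(is_train)   # heavier augmentation when is_train=True
            df = get_imagenet_dataflow(data_dir, 'train', batch, augs)
            print('Benchmarking with the {} augmentor:'.format(name))
            TestDataSpeed(df).start()
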
@@ -190,7 +190,16 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
                 use_vs=[False] + [True] * (len(self.towers) - 1))
         DataParallelBuilder._check_grad_list(grad_list)
 
-        grads = allreduce_grads(grad_list)
+        if True:
+            grads = allreduce_grads(grad_list)   # #gpu x #param x 2
+        else:
+            agg_grad_and_vars = average_grads(grad_list, colocation=False, devices=['/cpu:0'])   # #param x 2
+            grads = []  # #gpu x #param x 2
+            for grad_and_vars in grad_list:   # grad_and_vars: #paramx2
+                # take v from each tower, and g from average.
+                grads.append(
+                    [(g, v) for (_, v), (g, _) in zip(grad_and_vars, agg_grad_and_vars)])
 
         train_ops = []
         opt = get_opt_fn()
@@ -201,6 +210,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
                 train_ops.append(opt.apply_gradients(
                     grad_and_vars, name='apply_grad_{}'.format(idx)))
         train_op = tf.group(*train_ops, name='train_op')
         post_init_op = SyncMultiGPUReplicatedBuilder.get_post_init_ops()
         return train_op, post_init_op
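
The disabled else branch is the "average by copy" experiment named in the commit message: gradients are averaged once (on /cpu:0 via average_grads), and every tower then applies that same averaged gradient to its own copy of each variable. The re-pairing done by the list comprehension is easiest to see with plain-Python stand-ins; the sketch below is illustrative only and not part of the commit.

    # Plain-Python sketch of the re-pairing in the disabled branch above.
    # Strings stand in for tensors; 2 towers x 2 parameters.
    grad_list = [
        [('g0_t0', 'w0_t0'), ('g1_t0', 'w1_t0')],   # tower 0: (grad, var) per param
        [('g0_t1', 'w0_t1'), ('g1_t1', 'w1_t1')],   # tower 1
    ]
    # One averaged (grad, var) pair per parameter, e.g. computed on /cpu:0.
    agg_grad_and_vars = [('g0_avg', 'w0_t0'), ('g1_avg', 'w1_t0')]

    grads = []
    for grad_and_vars in grad_list:
        # take v from each tower, and g from the average
        grads.append(
            [(g, v) for (_, v), (g, _) in zip(grad_and_vars, agg_grad_and_vars)])

    # Every tower gets the same averaged gradient, paired with its own variable:
    assert grads[1] == [('g0_avg', 'w0_t1'), ('g1_avg', 'w1_t1')]
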
@@ -110,6 +110,7 @@ def allreduce_grads(all_grads):
         grads_for_a_var = []
         for (_, v), g in zip(grad_and_vars, summed):
             with tf.device(g.device):
+                # tensorflow/benchmarks didn't average gradients
                 g = tf.multiply(g, 1.0 / nr_tower)
                 grads_for_a_var.append((g, v))
         new_all_grads.append(grads_for_a_var)
@@ -119,25 +120,30 @@ def allreduce_grads(all_grads):
     return ret
 
 
-def average_grads(all_grads, colocation=True):
+def average_grads(all_grads, colocation=True, devices=None):
     """
-    Average the gradients, on the device of each variable.
+    Average the gradients.
 
     Args:
         all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
             The variables have to be the same across the K lists.
-        colocation (bool): colocate gradient averaging with the variable
+        colocation (bool): colocate gradient averaging on the device of the variable.
+        devices (list[str]): assign the averaging to these device in
+            round-robin. Cannot be used together with ``colocation``.
 
     Returns:
         (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
     """
+    assert not (devices is not None and colocation)
+    if devices is not None:
+        assert isinstance(devices, list), devices
     nr_tower = len(all_grads)
     if nr_tower == 1:
         return all_grads[0]
 
     ret = []
     with tf.name_scope('AvgGrad'):
-        for grad_and_vars in zip(*all_grads):
+        for idx, grad_and_vars in enumerate(zip(*all_grads)):
             # Ngpu * 2
             v = grad_and_vars[0][1]
             grads = [g for (g, _) in grad_and_vars]
@@ -146,9 +152,14 @@ def average_grads(all_grads, colocation=True):
                 with tf.device(v.device):       # colocate summed grad with var
                     grad = tf.multiply(
                         tf.add_n(grads), 1.0 / nr_tower)
-            else:
+            elif devices is None:
                 grad = tf.multiply(
                     tf.add_n(grads), 1.0 / nr_tower)
+            else:
+                dev = devices[idx % len(devices)]
+                with tf.device(dev):
+                    grad = tf.multiply(
+                        tf.add_n(grads), 1.0 / nr_tower)
             ret.append((grad, v))
     return ret
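
The new devices argument gives average_grads a third placement mode: instead of colocating each averaging op with its variable, the ops are spread over an explicit device list in round-robin order (the disabled builder branch above passes devices=['/cpu:0'], which pins all averaging to the CPU). The placement rule is just modular indexing over the parameter index; a self-contained illustration follows, with device strings that are only examples and not taken from the commit.

    # Round-robin placement rule used by the new `devices` argument.
    # The device names below are illustrative only.
    devices = ['/cpu:0', '/gpu:0']
    for idx in range(5):              # idx plays the role of the parameter index
        dev = devices[idx % len(devices)]
        print('param {} averaged on {}'.format(idx, dev))
    # param 0 averaged on /cpu:0
    # param 1 averaged on /gpu:0
    # param 2 averaged on /cpu:0
    # ...
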
@@ -344,7 +344,7 @@ class TensorInput(FeedfreeInput):
 
 class DummyConstantInput(TensorInput):
-    """ Input with some random tensor placed on GPU.
+    """ Input with a constant zero tensor placed on GPU.
         Useful for debugging performance issues """
     def __init__(self, shapes):
         """
@@ -233,14 +233,11 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         fused=False)
     xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
 
-    if ctx.has_own_variables:
-        # Only apply update in this case.
-        # Add these EMA to model_variables so that they will be synced
-        # properly by replicated trainers.
+    if ctx.is_main_training_tower:
         for v in layer.non_trainable_variables:
             add_model_variable(v)
     else:
-        # Don't need update if we are sharing variables from an existing tower
+        # only run UPDATE_OPS in the first tower
        restore_collection(coll_bk)
 
     if ndims == 2: