Commit f246dd59 authored by Yuxin Wu's avatar Yuxin Wu

fix SyncMultiGPUTrainerReplicated for single-GPU case

parent 8f4183e7
......@@ -249,7 +249,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
from tensorflow.contrib import nccl
nr_tower = len(tower_grads)
if nr_tower == 1:
return tower_grads[0]
return [[x] for x in tower_grads[0]]
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads):
......@@ -284,6 +284,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
for idx in range(self.config.nr_tower):
with tf.device(raw_devices[idx]):
grad_and_vars = [x[idx] for x in grads]
print(grad_and_vars)
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(idx)))
self.train_op = tf.group(*train_ops, name='train_op')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment