Commit 1021b385 authored by Yuxin Wu

async training: accumulate gradients on the variable's device

parent c420730d
@@ -325,9 +325,10 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         train_ops = []
         opt = self.model.get_optimizer()
-        for i in range(self.config.nr_tower):
-            with tf.device(raw_devices[i]):
-                grad_and_vars = grad_list[i]
+        for i, grad_and_vars in enumerate(zip(*grad_list)):
+            # Ngpu x 2: the (grad, var) pairs for one variable, one from each tower
+            v = grad_and_vars[0][1]
+            with tf.device(v.device):
                 train_ops.append(opt.apply_gradients(
                     grad_and_vars, name='apply_grad_{}'.format(i)))
         self.train_op = tf.group(*train_ops, name='train_op')
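For context, a minimal runnable sketch (with toy placeholder values, not tensorpack code) of the regrouping that zip(*grad_list) performs: each tower produces a list of (grad, var) pairs in the same variable order, and transposing gathers every tower's gradient for one variable into a single tuple, so apply_gradients can run on that variable's own device instead of on a fixed per-tower device.

    # Toy illustration with hypothetical string stand-ins for tensors/variables.
    # grad_list holds one list of (grad, var) pairs per tower, same var order.
    grad_list = [
        [('g0_w', 'w'), ('g0_b', 'b')],   # tower 0
        [('g1_w', 'w'), ('g1_b', 'b')],   # tower 1
    ]
    for i, grad_and_vars in enumerate(zip(*grad_list)):
        # grad_and_vars is Ngpu x 2: every tower's (grad, var) for one variable
        v = grad_and_vars[0][1]
        print(i, v, grad_and_vars)
    # prints:
    # 0 w (('g0_w', 'w'), ('g1_w', 'w'))
    # 1 b (('g0_b', 'b'), ('g1_b', 'b'))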