Commit fa0b0ca2 authored by Yuxin Wu

fix usage of collective_ops.nccl

parent 6151e048
...@@ -178,8 +178,10 @@ def allreduce_grads(all_grads, average, mode="nccl"): ...@@ -178,8 +178,10 @@ def allreduce_grads(all_grads, average, mode="nccl"):
for t in grads: for t in grads:
with tf.device(t.device): with tf.device(t.device):
t = collective_ops.all_reduce( t = collective_ops.all_reduce(
t, len(grads), shared_cnt, shared_cnt + 100, t, len(grads),
'Add', 'Id') 42, # group key is any fixed integer for a fixed group of devices
shared_cnt + 100,
'Add', 'Id', communication_hint='nccl')
summed.append(t) summed.append(t)
grads_for_devices = [] # K grads_for_devices = [] # K
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment