Commit 8915849e authored by Yuxin Wu's avatar Yuxin Wu

bugfix: persist collection order

parent 67de41d0
......@@ -134,10 +134,10 @@ def average_grads(all_grads):
for grad_and_vars in zip(*all_grads):
# Ngpu * 2
v = grad_and_vars[0][1]
all_grads = [g for (g, _) in grad_and_vars]
grads = [g for (g, _) in grad_and_vars]
with tf.device(v.device): # colocate summed grad with var
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_tower)
tf.add_n(grads), 1.0 / nr_tower)
ret.append((grad, v))
return ret
......@@ -155,6 +155,7 @@ class CollectionGuard(object):
"""
Get items from this collection that are added in the current tower.
"""
new = set(tf.get_collection(key))
new = tf.get_collection(key)
old = set(self.original.get(key, []))
return list(new - old)
# presist the order in new
return [x for x in new if x not in old]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment