Commit 8915849e authored by Yuxin Wu's avatar Yuxin Wu

bugfix: persist collection order

parent 67de41d0
...@@ -134,10 +134,10 @@ def average_grads(all_grads): ...@@ -134,10 +134,10 @@ def average_grads(all_grads):
for grad_and_vars in zip(*all_grads): for grad_and_vars in zip(*all_grads):
# Ngpu * 2 # Ngpu * 2
v = grad_and_vars[0][1] v = grad_and_vars[0][1]
all_grads = [g for (g, _) in grad_and_vars] grads = [g for (g, _) in grad_and_vars]
with tf.device(v.device): # colocate summed grad with var with tf.device(v.device): # colocate summed grad with var
grad = tf.multiply( grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_tower) tf.add_n(grads), 1.0 / nr_tower)
ret.append((grad, v)) ret.append((grad, v))
return ret return ret
...@@ -155,6 +155,7 @@ class CollectionGuard(object): ...@@ -155,6 +155,7 @@ class CollectionGuard(object):
""" """
Get items from this collection that are added in the current tower. Get items from this collection that are added in the current tower.
""" """
new = set(tf.get_collection(key)) new = tf.get_collection(key)
old = set(self.original.get(key, [])) old = set(self.original.get(key, []))
return list(new - old) # presist the order in new
return [x for x in new if x not in old]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment