Commit 02e53f72 authored by Yuxin Wu

make add_moving_summary use local variables, so they are not broadcasted

parent 2d661d6d
@@ -245,10 +245,11 @@ def add_moving_summary(*args, **kwargs):
         assert x.get_shape().ndims == 0, \
             "add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())
+    from ..graph_builder.utils import override_to_local_variable
     ema_ops = []
     for c in args:
         name = re.sub('tower[0-9]+/', '', c.op.name)
-        with tf.name_scope(None):
+        with tf.name_scope(None), override_to_local_variable(True):
             if not c.dtype.is_floating:
                 c = tf.cast(c, tf.float32)
             # assign_moving_average creates variables with op names, therefore clear ns first.
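The effect of this change is that the EMA variables created by assign_moving_average inside add_moving_summary land in the LOCAL_VARIABLES collection rather than GLOBAL_VARIABLES, so they no longer appear in tf.global_variables() and are excluded from HorovodTrainer's broadcast (see the second diff below). A minimal sketch of how a helper like override_to_local_variable could work, assuming it is implemented as a custom variable getter (an illustration, not tensorpack's actual implementation):

from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def override_to_local_variable(enable=True):
    # Sketch: every variable created inside this block becomes a local variable.
    if not enable:
        yield
        return

    def custom_getter(getter, name, *args, **kwargs):
        # Register new variables only in LOCAL_VARIABLES so they do not show up
        # in tf.global_variables() and are skipped by a global-variable broadcast.
        kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
        return getter(name, *args, **kwargs)

    # Re-enter the current variable scope with the custom getter attached.
    with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
        yield

Note that variables in LOCAL_VARIABLES are not covered by tf.global_variables_initializer(); each worker needs tf.local_variables_initializer() (or an equivalent) to initialize them.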
@@ -474,7 +474,7 @@ class HorovodTrainer(SingleCostTrainer):
         return [cb]

     def broadcast(self, _):
-        logger.info("Running broadcast ...")
+        logger.info("Broadcasting {} global variables ...".format(self._num_global_variables))
         # the op will be created in initialize()
         self.sess.run(self._broadcast_op)
@@ -483,6 +483,7 @@ class HorovodTrainer(SingleCostTrainer):
         # broadcast_op should be the last setup_graph: it needs to be created
         # "right before" the graph is finalized,
         # because it needs to capture all the variables (which may be created by callbacks).
+        self._num_global_variables = len(tf.global_variables())
         self._broadcast_op = self.hvd.broadcast_global_variables(0)
         # it's important that our NewSessionCreator does not finalize the graph
@@ -504,9 +505,10 @@ class HorovodTrainer(SingleCostTrainer):
         # 1. a allgather helper to concat strings
         # 2. check variables on each rank match each other, print warnings, and broadcast the common set.
         if self.is_chief:
-            logger.info("Broadcasting initialized variables ...")
+            logger.info("Broadcasting initialization of {} global variables ...".format(self._num_global_variables))
         else:
-            logger.info("Rank {} waiting for initialization broadcasting ...".format(self._rank))
+            logger.info("Rank {} waiting for initialization of {} variables ...".format(
+                self._rank, self._num_global_variables))
         self.sess.run(self._broadcast_op)
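For context on why self._num_global_variables is recorded right before the broadcast op is built: hvd.broadcast_global_variables(0) captures whatever tf.global_variables() contains at creation time, so taking the count at the same moment keeps the new log messages consistent with the set of variables actually broadcast. A standalone sketch under an assumed minimal Horovod setup (not the trainer's code):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# A global variable: included in tf.global_variables(), broadcast from rank 0.
w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
# A local variable (e.g. an EMA buffer after this commit): not broadcast.
ema = tf.get_variable("ema", shape=[], initializer=tf.zeros_initializer(),
                      collections=[tf.GraphKeys.LOCAL_VARIABLES])

# Snapshot the count at the same moment the broadcast op is created,
# so the log message matches what the op will actually send.
num_global_variables = len(tf.global_variables())
broadcast_op = hvd.broadcast_global_variables(0)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    print("Broadcasting {} global variables ...".format(num_global_variables))
    sess.run(broadcast_op)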