Commit 02e53f72 authored by Yuxin Wu's avatar Yuxin Wu

make add_moving_summary use local variables, so they are not broadcasted

parent 2d661d6d
......@@ -245,10 +245,11 @@ def add_moving_summary(*args, **kwargs):
assert x.get_shape().ndims == 0, \
"add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())
from ..graph_builder.utils import override_to_local_variable
ema_ops = []
for c in args:
name = re.sub('tower[0-9]+/', '', c.op.name)
with tf.name_scope(None):
with tf.name_scope(None), override_to_local_variable(True):
if not c.dtype.is_floating:
c = tf.cast(c, tf.float32)
# assign_moving_average creates variables with op names, therefore clear ns first.
......
......@@ -474,7 +474,7 @@ class HorovodTrainer(SingleCostTrainer):
return [cb]
def broadcast(self, _):
logger.info("Running broadcast ...")
logger.info("Broadcasting {} global variables ...".format(self._num_global_variables))
# the op will be created in initialize()
self.sess.run(self._broadcast_op)
......@@ -483,6 +483,7 @@ class HorovodTrainer(SingleCostTrainer):
# broadcast_op should be the last setup_graph: it needs to be created
# "right before" the graph is finalized,
# because it needs to capture all the variables (which may be created by callbacks).
self._num_global_variables = len(tf.global_variables())
self._broadcast_op = self.hvd.broadcast_global_variables(0)
# it's important that our NewSessionCreator does not finalize the graph
......@@ -504,9 +505,10 @@ class HorovodTrainer(SingleCostTrainer):
# 1. a allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set.
if self.is_chief:
logger.info("Broadcasting initialized variables ...")
logger.info("Broadcasting initialization of {} global variables ...".format(self._num_global_variables))
else:
logger.info("Rank {} waiting for initialization broadcasting ...".format(self._rank))
logger.info("Rank {} waiting for initialization of {} variables ...".format(
self._rank, self._num_global_variables))
self.sess.run(self._broadcast_op)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment