Commit 8f8f9179 authored by Yuxin Wu

fix the broadcast stage of horovod trainer.

parent 822997c7
......@@ -380,9 +380,9 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op')
with tf.name_scope('horovod_broadcast'):
op = hvd.broadcast_global_variables(0)
self._broadcast_op = hvd.broadcast_global_variables(0)
cb = RunOp(
op, run_before=True,
self._broadcast_op, run_before=False,
run_as_trigger=True, verbose=True)
return [cb]
......@@ -398,8 +398,12 @@ class HorovodTrainer(SingleCostTrainer):
session_creator.config.inter_op_parallelism_threads = mp.cpu_count() // hvd.local_size()
except AttributeError: # old horovod does not have local_size
pass
super(HorovodTrainer, self).initialize(
session_creator, session_init)
super(HorovodTrainer, self).initialize(session_creator, session_init)
# This broadcast belongs to the "initialize" stage
# It should not be delayed to the "before_train" stage.
logger.info("Broadcasting initialized variables ...")
self.sess.run(self._broadcast_op)
from ..utils.develop import create_dummy_class # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment