Commit 8f8f9179 authored by Yuxin Wu's avatar Yuxin Wu

fix the broadcast stage of horovod trainer.

parent 822997c7
...@@ -380,9 +380,9 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -380,9 +380,9 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn() opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op') self.train_op = opt.apply_gradients(grads, name='min_op')
with tf.name_scope('horovod_broadcast'): with tf.name_scope('horovod_broadcast'):
op = hvd.broadcast_global_variables(0) self._broadcast_op = hvd.broadcast_global_variables(0)
cb = RunOp( cb = RunOp(
op, run_before=True, self._broadcast_op, run_before=False,
run_as_trigger=True, verbose=True) run_as_trigger=True, verbose=True)
return [cb] return [cb]
...@@ -398,8 +398,12 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -398,8 +398,12 @@ class HorovodTrainer(SingleCostTrainer):
session_creator.config.inter_op_parallelism_threads = mp.cpu_count() // hvd.local_size() session_creator.config.inter_op_parallelism_threads = mp.cpu_count() // hvd.local_size()
except AttributeError: # old horovod does not have local_size except AttributeError: # old horovod does not have local_size
pass pass
super(HorovodTrainer, self).initialize( super(HorovodTrainer, self).initialize(session_creator, session_init)
session_creator, session_init)
# This broadcast belongs to the "intialize" stage
# It should not be delayed to the "before_train" stage.
logger.info("Broadcasting initialized variables ...")
self.sess.run(self._broadcast_op)
from ..utils.develop import create_dummy_class # noqa from ..utils.develop import create_dummy_class # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment