Commit 0146287b authored by Yuxin Wu

add back horovod broadcast

parent f43309f0
......@@ -424,15 +424,14 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='train_op')
def broadcast(self):
logger.info("Running broadcast ...")
# the op will be created later in initialize()
self.trainer._broadcast_op.run()
# TODO provide a way to sync manually
cb = CallbackFactory(before_train=broadcast).set_chief_only(False)
cb = CallbackFactory(before_train=self.broadcast, trigger=self.broadcast).set_chief_only(False)
return [cb]
def broadcast(self, _):
    """Log a notice, then run ``self._broadcast_op`` in ``self.sess``.

    The single positional argument is accepted and ignored.
    (Presumably it is the callback object handed over by the
    callback machinery -- confirm against the caller.)
    """
    logger.info("Running broadcast ...")
    # The op is not built in this method; it is created in initialize().
    session, op = self.sess, self._broadcast_op
    session.run(op)
@HIDE_DOC
def initialize(self, session_creator, session_init):
# broadcast_op should be the last setup_graph: it needs to be created
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment