Commit 1a262e8c authored by Yuxin Wu's avatar Yuxin Wu

Initial HorovodTrainer (#422)

parent 7b0782d6
...@@ -174,7 +174,7 @@ class Trainer(object): ...@@ -174,7 +174,7 @@ class Trainer(object):
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
session_init._run_init(self.sess) session_init._run_init(self.sess)
else: else:
if not isinstance(self._config.session_init, JustCurrentSession): if not isinstance(session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!") logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize() self.sess.graph.finalize()
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
from ..callbacks.graph import RunOp from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..utils import logger from ..utils import logger
...@@ -29,7 +29,8 @@ __all__ = ['SimpleTrainer', ...@@ -29,7 +29,8 @@ __all__ = ['SimpleTrainer',
'SyncMultiGPUTrainerReplicated', 'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer', 'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer', 'AsyncMultiGPUTrainer',
'DistributedTrainerReplicated'] 'DistributedTrainerReplicated',
'HorovodTrainer']
def _int_to_range(x): def _int_to_range(x):
...@@ -206,3 +207,29 @@ class DistributedTrainerReplicated(SingleCostTrainer): ...@@ -206,3 +207,29 @@ class DistributedTrainerReplicated(SingleCostTrainer):
@property @property
def _main_tower_vs_name(self): def _main_tower_vs_name(self):
return "tower0" return "tower0"
class HorovodTrainer(SingleCostTrainer):
def __init__(self):
hvd.init()
self.is_chief = hvd.rank() == 0
logger.info("Horovod local rank: {}".format(hvd.local_rank()))
super(HorovodTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
with TowerContext('', is_training=True):
grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
opt = get_opt_fn()
opt = hvd.DistributedOptimizer(opt)
self.train_op = opt.apply_gradients(grads, name='min_op')
return [RunOp(
hvd.broadcast_global_variables(0),
run_before=True,
run_as_trigger=False, verbose=True)]
from ..utils.develop import create_dummy_class # noqa
try:
import horovod.tensorflow as hvd
except ImportError:
HorovodTrainer = create_dummy_class('HovorodTrainer', 'horovod') # noqa
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment