Commit 1a262e8c authored by Yuxin Wu

Initial HorovodTrainer (#422)

parent 7b0782d6
@@ -174,7 +174,7 @@ class Trainer(object):
             logger.info("Initializing the session ...")
             session_init._run_init(self.sess)
         else:
-            if not isinstance(self._config.session_init, JustCurrentSession):
+            if not isinstance(session_init, JustCurrentSession):
                 logger.warn("This is not a chief worker, 'session_init' was ignored!")
 
         self.sess.graph.finalize()
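This one-line fix matters because the rewritten Trainer.initialize() receives session_init as an argument rather than reading it off a config object, so the old check against self._config.session_init tested the wrong value on non-chief workers. A minimal sketch of the surrounding method; the signature and is_chief flag are assumed from this diff rather than shown in it:

    def initialize(self, session_creator, session_init):
        self.sess = session_creator.create_session()
        if self.is_chief:
            logger.info("Initializing the session ...")
            session_init._run_init(self.sess)
        else:
            # Non-chief workers get variable values from the chief, so any
            # initializer other than JustCurrentSession would be overwritten.
            if not isinstance(session_init, JustCurrentSession):
                logger.warn("This is not a chief worker, 'session_init' was ignored!")

        self.sess.graph.finalize()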
@@ -4,7 +4,7 @@
 import os
 
-from ..callbacks.graph import RunOp
+from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
@@ -29,7 +29,8 @@ __all__ = ['SimpleTrainer',
            'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
-           'DistributedTrainerReplicated']
+           'DistributedTrainerReplicated',
+           'HorovodTrainer']
 
 
 def _int_to_range(x):
@@ -206,3 +207,29 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     @property
     def _main_tower_vs_name(self):
         return "tower0"
+
+
+class HorovodTrainer(SingleCostTrainer):
+    def __init__(self):
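+        # hvd.init() must be the first Horovod call in the process; it
+        # attaches this process to the MPI job so that hvd.rank() and
+        # hvd.local_rank() below return meaningful values.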
+        hvd.init()
+        self.is_chief = hvd.rank() == 0
+        logger.info("Horovod local rank: {}".format(hvd.local_rank()))
+        super(HorovodTrainer, self).__init__()
+
+    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
+        with TowerContext('', is_training=True):
+            grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
+            opt = get_opt_fn()
+            opt = hvd.DistributedOptimizer(opt)
+            self.train_op = opt.apply_gradients(grads, name='min_op')
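+        # Broadcast rank 0's initial variable values to all other processes
+        # before training begins, so every worker starts from identical
+        # weights; RunOp(run_before=True) executes the op once at startup.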
+        return [RunOp(
+            hvd.broadcast_global_variables(0),
+            run_before=True,
+            run_as_trigger=False, verbose=True)]
+
+
+from ..utils.develop import create_dummy_class   # noqa
+try:
+    import horovod.tensorflow as hvd
+except ImportError:
+    HorovodTrainer = create_dummy_class('HorovodTrainer', 'horovod')    # noqa
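For anyone trying the new trainer, a minimal usage sketch. The TrainConfig/launch_train_with_config names follow tensorpack's existing trainer interface, and the model and dataflow are hypothetical placeholders, not part of this commit:

    # train.py -- launch with one process per GPU, e.g.:
    #   mpirun -np 4 python train.py
    from tensorpack import TrainConfig, launch_train_with_config
    from tensorpack.train import HorovodTrainer

    config = TrainConfig(
        model=MyModel(),        # hypothetical ModelDesc subclass
        dataflow=my_dataflow,   # hypothetical per-worker DataFlow
        max_epoch=100,
    )
    # Each MPI process builds its own copy of the graph; the wrapped
    # hvd.DistributedOptimizer averages gradients across processes.
    launch_train_with_config(config, HorovodTrainer())

Since hvd.init() runs inside the trainer's constructor, the script itself needs no Horovod-specific setup beyond being launched under mpirun.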