Commit 74ca05dc authored by Yuxin Wu

horovod now works on multigpu (#422)

parent 1a262e8c
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: trainers.py # File: trainers.py
import os import os
import tensorflow as tf
from ..callbacks import RunOp from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
...@@ -213,7 +214,12 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -213,7 +214,12 @@ class HorovodTrainer(SingleCostTrainer):
def __init__(self):
    """Initialize Horovod and pick the GPU assigned to this process.

    Calls ``hvd.init()``, marks the process whose ``hvd.rank()`` is 0 as
    chief, and stores in ``self._device`` the entry of the comma-separated
    ``CUDA_VISIBLE_DEVICES`` list found at index ``hvd.local_rank()``.

    Raises:
        KeyError: if ``CUDA_VISIBLE_DEVICES`` is not set.
        ValueError: if an entry of ``CUDA_VISIBLE_DEVICES`` is not an integer.
        AssertionError: if ``CUDA_VISIBLE_DEVICES`` lists too few devices
            for this process's local rank.
    """
    hvd.init()
    self.is_chief = hvd.rank() == 0
    local_rank = hvd.local_rank()
    devices = [int(d) for d in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
    # local_rank is used as an index below, so the list must be strictly
    # longer than it; `>=` would let len(devices) == local_rank through and
    # fail later with an IndexError.
    assert len(devices) > local_rank, \
        "CUDA_VISIBLE_DEVICES lists {} device(s), not enough for local rank {}".format(
            len(devices), local_rank)
    self._device = devices[local_rank]
    logger.info("Horovod local rank={}, device={}".format(local_rank, self._device))
    super(HorovodTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
...@@ -222,10 +228,20 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -222,10 +228,20 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn() opt = get_opt_fn()
opt = hvd.DistributedOptimizer(opt) opt = hvd.DistributedOptimizer(opt)
self.train_op = opt.apply_gradients(grads, name='min_op') self.train_op = opt.apply_gradients(grads, name='min_op')
return [RunOp( cb = RunOp(
hvd.broadcast_global_variables(0), tf.identity(hvd.broadcast_global_variables(0), name='horovod_broadcast_global_variables'),
run_before=True, run_before=True,
run_as_trigger=False, verbose=True)] run_as_trigger=False, verbose=True)
cb.chief_only = False
return [cb]
def initialize(self, session_creator, session_init):
    """Restrict the session to this process's GPU, then initialize the trainer.

    Args:
        session_creator: must be a ``NewSessionCreator``; its session config
            is modified in place to set ``gpu_options.visible_device_list``.
        session_init: passed through unchanged to the parent ``initialize``.

    Raises:
        ValueError: if ``session_creator`` is not a ``NewSessionCreator``.
    """
    if not isinstance(session_creator, NewSessionCreator):
        raise ValueError(
            "Cannot set session_creator for horovod training! ")
    # visible_device_list is interpreted relative to the GPUs the process can
    # already see (i.e. after CUDA_VISIBLE_DEVICES is applied), so it must be
    # the local rank, not the physical id parsed from CUDA_VISIBLE_DEVICES.
    # With e.g. CUDA_VISIBLE_DEVICES=2,3 the physical id "2" is out of range.
    session_creator._config.gpu_options.visible_device_list = str(hvd.local_rank())
    super(HorovodTrainer, self).initialize(
        session_creator, session_init)
from ..utils.develop import create_dummy_class # noqa from ..utils.develop import create_dummy_class # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment