Commit 86c9df35 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 74ca05dc
...@@ -34,6 +34,7 @@ class NewSessionCreator(tf.train.ChiefSessionCreator): ...@@ -34,6 +34,7 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
else: else:
self.user_provided_config = True self.user_provided_config = True
self.config = config
super(NewSessionCreator, self).__init__(master=target, config=config) super(NewSessionCreator, self).__init__(master=target, config=config)
......
...@@ -211,15 +211,16 @@ class DistributedTrainerReplicated(SingleCostTrainer): ...@@ -211,15 +211,16 @@ class DistributedTrainerReplicated(SingleCostTrainer):
class HorovodTrainer(SingleCostTrainer): class HorovodTrainer(SingleCostTrainer):
"""
Horovod trainer; currently supports multi-GPU training.
It will use the first k GPUs in CUDA_VISIBLE_DEVICES.
"""
def __init__(self): def __init__(self):
hvd.init() hvd.init()
self.is_chief = hvd.rank() == 0 self.is_chief = hvd.rank() == 0
local_rank = hvd.local_rank() self._local_rank = hvd.local_rank()
devices = os.environ['CUDA_VISIBLE_DEVICES'] logger.info("Horovod local rank={}".format(self._local_rank))
devices = list(map(int, devices.split(',')))
assert len(devices) >= local_rank
self._device = devices[local_rank]
logger.info("Horovod local rank={}, device={}".format(local_rank, self._device))
super(HorovodTrainer, self).__init__() super(HorovodTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
...@@ -239,7 +240,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -239,7 +240,7 @@ class HorovodTrainer(SingleCostTrainer):
if not isinstance(session_creator, NewSessionCreator): if not isinstance(session_creator, NewSessionCreator):
raise ValueError( raise ValueError(
"Cannot set session_creator for horovod training! ") "Cannot set session_creator for horovod training! ")
session_creator._config.gpu_options.visible_device_list = str(self._device) session_creator.config.gpu_options.visible_device_list = str(self._local_rank)
super(HorovodTrainer, self).initialize( super(HorovodTrainer, self).initialize(
session_creator, session_init) session_creator, session_init)
......
# $File: Makefile # $File: Makefile
# $Date: Thu Aug 03 16:14:29 2017 -0700 # $Date: Tue Oct 31 11:44:27 2017 +0800
OBJ_DIR = obj OBJ_DIR = obj
...@@ -40,6 +40,7 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g') ...@@ -40,6 +40,7 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o)) OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d) DEPFILES = $(OBJS:.o=.d)
# TODO what about mac?
SO = $(ccSOURCES:.cc=.so) SO = $(ccSOURCES:.cc=.so)
.PHONY: all clean .PHONY: all clean
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: common.py # File: common.py
from __future__ import print_function from __future__ import print_function
import sysconfig
import tensorflow as tf import tensorflow as tf
import os import os
...@@ -16,5 +17,19 @@ def compile(): ...@@ -16,5 +17,19 @@ def compile():
return ret return ret
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix():
    """Return the filename suffix to use for compiled extension libraries.

    The ``EXT_SUFFIX`` sysconfig variable is consulted first, then ``SO``;
    the first non-empty value found is returned. If neither variable yields
    a value, the literal ``'.so'`` is used as a fallback.
    """
    for var_name in ('EXT_SUFFIX', 'SO'):
        suffix = sysconfig.get_config_var(var_name)
        if suffix:
            return suffix
    return '.so'
if __name__ == '__main__': if __name__ == '__main__':
compile() compile()
...@@ -11,12 +11,13 @@ from tensorflow.core.framework.tensor_pb2 import TensorProto ...@@ -11,12 +11,13 @@ from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DataType from tensorflow.core.framework import types_pb2 as DataType
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce # have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
from .common import compile from .common import compile, get_ext_suffix
__all__ = ['zmq_recv', 'dumps_for_tfop', __all__ = ['zmq_recv', 'dumps_for_tfop',
'dump_tensor_protos', 'to_tensor_proto'] 'dump_tensor_protos', 'to_tensor_proto']
# TODO '.so' for linux only
def build(): def build():
global zmq_recv global zmq_recv
ret = compile() ret = compile()
...@@ -25,7 +26,7 @@ def build(): ...@@ -25,7 +26,7 @@ def build():
else: else:
file_dir = os.path.dirname(os.path.abspath(__file__)) file_dir = os.path.dirname(os.path.abspath(__file__))
recv_mod = tf.load_op_library( recv_mod = tf.load_op_library(
os.path.join(file_dir, 'zmq_recv_op.so')) os.path.join(file_dir, 'zmq_recv_op.' + get_ext_suffix()))
zmq_recv = recv_mod.zmq_recv zmq_recv = recv_mod.zmq_recv
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment