Commit 46991853 authored by Yuxin Wu

simplify code for distributed

parent 2d6d7ad4
......@@ -10,7 +10,7 @@ import tensorflow as tf
__all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice',
'OverrideToLocalVariable', 'override_to_local_variable',
'override_to_local_variable',
'allreduce_grads',
'average_grads']
......@@ -20,23 +20,7 @@ Some utilities for building the graph.
"""
@contextmanager
def override_to_local_variable(enable=True):
if enable:
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
yield
else:
yield
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
def _replace_global_by_local(kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
......@@ -46,8 +30,23 @@ class OverrideToLocalVariable(object):
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
@contextmanager
def override_to_local_variable(enable=True):
if enable:
def custom_getter(getter, name, *args, **kwargs):
_replace_global_by_local(kwargs)
return getter(name, *args, **kwargs)
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=custom_getter):
yield
else:
yield
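For context, a minimal usage sketch (not part of this commit) of the custom-getter approach above, assuming TF 1.x and that override_to_local_variable is importable from this module:

import tensorflow as tf
from tensorpack.graph_builder.utils import override_to_local_variable  # assumed import path

with override_to_local_variable():
    # Created under the custom getter, this variable is added to
    # LOCAL_VARIABLES instead of GLOBAL_VARIABLES, so it is neither
    # saved by the default saver nor treated as a global variable.
    step = tf.get_variable('local_step', shape=[], dtype=tf.int64,
                           initializer=tf.constant_initializer(0))

assert step in tf.local_variables()
assert step not in tf.global_variables()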
# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L192-L218
class LeastLoadedDeviceSetter(object):
......@@ -170,15 +169,8 @@ class OverrideCachingDevice(object):
def __call__(self, getter, *args, **kwargs):
size = tf.TensorShape(kwargs['shape']).num_elements()
if size is None or not kwargs.get('trainable', True):
# TODO
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
# TODO a lot of vars won't be saved then
_replace_global_by_local(kwargs)
return getter(*args, **kwargs)
if size < self.small_variable_size_threshold:
......
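For reference, a hedged sketch of how a least-loaded device setter is used with tf.device (following the linked tf_cnn_benchmarks pattern); the (worker_device, ps_devices) constructor and the device strings are assumptions, not taken from this diff:

import tensorflow as tf
from tensorpack.graph_builder.utils import LeastLoadedDeviceSetter  # assumed import path

# Assumed constructor: the worker device plus the list of ps devices.
ps_devices = ['/job:ps/task:0/cpu:0', '/job:ps/task:1/cpu:0']
device_setter = LeastLoadedDeviceSetter('/job:worker/task:0/gpu:0', ps_devices)

with tf.device(device_setter):
    # Ops are placed on the worker device; each variable goes to whichever
    # ps device currently holds the fewest bytes.
    w = tf.get_variable('w', shape=[1024, 256])
    y = tf.matmul(tf.random_normal([8, 1024]), w)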
......@@ -157,46 +157,23 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
return [cb]
class DistributedTrainerParameterServer(SingleCostTrainer):
__doc__ = DistributedParameterServerBuilder.__doc__
class DistributedTrainerBase(SingleCostTrainer):
devices = None
"""
List of GPU ids.
"""
# TODO use full device name instead of id
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, server, caching_device='cpu'):
"""
Args:
gpus ([int]): list of GPU ids.
"""
def __init__(self, gpus, server):
super(DistributedTrainerBase, self).__init__()
self.devices = gpus
self.server = server
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
super(DistributedTrainerParameterServer, self).__init__()
if self.job_name == 'ps':
# ps shouldn't setup input either
logger.info("Running ps {}".format(self.server.server_def.task_index))
def join(self):
logger.info("Calling server.join() on {}:{}".format(self.job_name, self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this function will never return tensorflow#4713
raise RuntimeError("This is a bug. Server.join() for ps should never return!")
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
raise RuntimeError("This is a bug. Server.join() for should never return!")
@HIDE_DOC
def initialize(self, session_creator, session_init):
......@@ -205,18 +182,37 @@ class DistributedTrainerParameterServer(SingleCostTrainer):
raise ValueError(
"You are not allowed to set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerParameterServer, self).initialize(
super(DistributedTrainerBase, self).initialize(
get_distributed_session_creator(self.server), session_init)
class DistributedTrainerReplicated(SingleCostTrainer):
class DistributedTrainerParameterServer(DistributedTrainerBase):
__doc__ = DistributedReplicatedBuilder.__doc__
__doc__ = DistributedParameterServerBuilder.__doc__
devices = None
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, server, caching_device='cpu'):
"""
List of GPU ids.
Args:
gpus ([int]): list of GPU ids.
"""
super(DistributedTrainerParameterServer, self).__init__(gpus, server)
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'ps':
self.join()
self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
self.is_chief = self._builder.is_chief
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
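A hedged sketch (not in the diff) of how the ps/worker split in DistributedTrainerBase plays out at launch; the cluster addresses are placeholders, and a custom session config goes to tf.train.Server, as the initialize() check above requires:

import tensorflow as tf
from tensorpack.train import DistributedTrainerParameterServer  # assumed import path

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
# Any custom session config must be given to the Server, not to the trainer.
config = tf.ConfigProto(allow_soft_placement=True)
server = tf.train.Server(cluster, job_name='worker', task_index=0, config=config)

# On a 'ps' job the constructor calls self.join() and never returns;
# on a 'worker' job it builds the graph on the listed GPUs.
trainer = DistributedTrainerParameterServer(gpus=[0, 1], server=server)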
class DistributedTrainerReplicated(SingleCostTrainer):
__doc__ = DistributedReplicatedBuilder.__doc__
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, server):
......@@ -225,26 +221,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
gpus (list[int]): list of GPU ids.
server (tf.train.Server): the server with ps and workers.
"""
self.devices = gpus
self.server = server
self.job_name = server.server_def.job_name
super(DistributedTrainerReplicated, self).__init__(gpus, server)
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'ps':
self.join()
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedReplicatedBuilder(gpus, server)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
super(DistributedTrainerReplicated, self).__init__()
if self.job_name == 'ps':
# ps shouldn't setup input either
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this function will never return tensorflow#4713
raise RuntimeError("This is a bug. Server.join() for ps should never return!")
def _setup_input(self, inputs_desc, input):
with override_to_local_variable():
......@@ -276,16 +259,6 @@ class DistributedTrainerReplicated(SingleCostTrainer):
callbacks.append(cb)
return callbacks
@HIDE_DOC
def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator) or \
session_creator.user_provided_config:
raise ValueError(
"You are not allowed to set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerReplicated, self).initialize(
get_distributed_session_creator(self.server), session_init)
@property
def _main_tower_vs_name(self):
return "tower0"
......
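The replicated trainer inherits the same ps behaviour; a hedged sketch of the ps-side process (addresses are placeholders, not from the diff):

import tensorflow as tf
from tensorpack.train import DistributedTrainerReplicated  # assumed import path

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='ps', task_index=0)

# With job_name == 'ps' the constructor calls join(), which blocks forever
# (see tensorflow#4713), so nothing after this line runs on a ps host.
trainer = DistributedTrainerReplicated(gpus=[0, 1], server=server)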
......@@ -4,5 +4,4 @@
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)