Commit 2d6d7ad4 authored by Yuxin Wu

Add Distributed PS Trainer. Missing some features such as saving model vars (#493)

parent 3145729a
@@ -9,13 +9,106 @@ from six.moves import zip, range

from ..utils.argtools import memoized
from ..tfutils.common import get_op_tensor_name, get_global_step_var

from .training import GraphBuilder, DataParallelBuilder
from .utils import (
    override_to_local_variable, average_grads,
    OverrideCachingDevice)

__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']


class DistributedBuilderBase(GraphBuilder):

    _sync_queue_counter = 0

    def __init__(self, server):
        self.server = server
        server_def = server.server_def
        self.cluster = tf.train.ClusterSpec(server_def.cluster)
        self.task_index = server_def.task_index

        self.num_ps = self.cluster.num_tasks('ps')
        self.num_worker = self.cluster.num_tasks('worker')

    def _add_sync_queues_and_barrier(self, name, dependencies):
        """Adds ops to enqueue on all worker queues.

        Args:
            name: prefix for the shared_name of ops.
            dependencies: ops to take control dependencies on.

        Returns:
            an op that should be used as a control dependency before starting the next step.
        """
        self._sync_queue_counter += 1
        with tf.device(self.sync_queue_devices[self._sync_queue_counter % len(self.sync_queue_devices)]):
            sync_queues = [
                tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
                             shared_name='%s%s' % (name, i))
                for i in range(self.num_worker)]
            queue_ops = []
            # For each other worker, add an entry to its queue, signaling that it can finish this step.
            token = tf.constant(False)
            with tf.control_dependencies(dependencies):
                for i, q in enumerate(sync_queues):
                    if i != self.task_index:
                        queue_ops.append(q.enqueue(token))

            # Drain tokens off the queue for this worker, one for each other worker.
            queue_ops.append(
                sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))

            return tf.group(*queue_ops, name=name)
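For context: DistributedBuilderBase only needs a tf.train.Server whose server_def carries the cluster description; the 'ps' and 'worker' task counts are read straight from it. A minimal sketch of the kind of server each process would construct (all host addresses below are made-up placeholders):

    import tensorflow as tf

    # Hypothetical 1-ps / 2-worker cluster; every address is a placeholder.
    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    # Each process runs exactly one task; this one is the first worker.
    server = tf.train.Server(cluster, job_name='worker', task_index=0)
    # DistributedBuilderBase(server) would then see num_ps == 1 and num_worker == 2,
    # read from server.server_def.cluster exactly as in __init__ above.
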
class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderBase):
    def __init__(self, towers, server, caching_device):
        DataParallelBuilder.__init__(self, towers)
        DistributedBuilderBase.__init__(self, server)

        assert caching_device in ['cpu', 'gpu'], caching_device
        self.caching_device = caching_device

        self.is_chief = (self.task_index == 0)

        worker_prefix = '/job:worker/task:%s' % self.task_index
        self.param_server_device = tf.train.replica_device_setter(
            worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)

        self.cpu_device = '%s/cpu:0' % worker_prefix
        self.raw_devices = ['{}/gpu:{}'.format(worker_prefix, k) for k in self.towers]

        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]

    def build(self, get_grad_fn, get_opt_fn):
        ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(
            self.num_ps, tf.contrib.training.byte_size_load_fn)
        devices = [
            tf.train.replica_device_setter(
                worker_device=d,
                cluster=self.cluster,
                ps_strategy=ps_strategy) for d in self.raw_devices]

        if self.caching_device == 'gpu':
            caching_devices = self.raw_devices
        else:
            caching_devices = [self.cpu_device]
        custom_getter = OverrideCachingDevice(
            caching_devices, self.cpu_device, 1024 * 64)
        with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
            grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
        DataParallelBuilder._check_grad_list(grad_list)

        with tf.device(self.param_server_device):
            grads = average_grads(grad_list, colocation=False)
            opt = get_opt_fn()
            train_op = opt.apply_gradients(grads, name='train_op')
        train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])
        return train_op
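A rough usage sketch of the builder above, assuming a worker process with a tf.train.Server like the one sketched earlier; get_grad_fn and get_opt_fn are stand-ins for the callables a trainer normally supplies (a function returning a list of (gradient, variable) pairs for one tower, and a function returning a tf.train.Optimizer):

    # Sketch only -- get_grad_fn / get_opt_fn are placeholders for what the
    # single-cost trainer normally provides.
    builder = DistributedParameterServerBuilder(
        towers=[0, 1], server=server, caching_device='cpu')
    train_op = builder.build(get_grad_fn, get_opt_fn)
    # train_op applies the averaged gradients on the PS devices and then
    # waits on the shared FIFOQueue barrier, so a step completes only after
    # every worker has reached the barrier.
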
class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
    """
    Distributed replicated training.
    Each worker process builds the same model on one or more GPUs.

@@ -62,27 +155,21 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
                The job_name must be 'worker' because 'ps' job doesn't need to
                build any graph.
        """
        DataParallelBuilder.__init__(self, towers)
        DistributedBuilderBase.__init__(self, server)

        self.is_chief = (self.task_index == 0)

        worker_prefix = '/job:worker/task:%s' % self.task_index
        self.param_server_device = tf.train.replica_device_setter(
            worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)

        self.nr_gpu = len(self.towers)
        self.cpu_device = '%s/cpu:0' % worker_prefix
        self.raw_devices = ['%s/gpu:%i' % (worker_prefix, i) for i in towers]

        # Device for queues for managing synchronization between servers
        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]

    @staticmethod
    def _average_grads(tower_grads, devices):

@@ -156,36 +243,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
            shadow_model_vars.append((new_v, v))  # only need to sync model_var from one tower
        return shadow_model_vars

    def build(self, get_grad_fn, get_opt_fn):
        """
        Args:
...
@@ -169,11 +169,18 @@ class OverrideCachingDevice(object):

    def __call__(self, getter, *args, **kwargs):
        size = tf.TensorShape(kwargs['shape']).num_elements()
        if size is None or not kwargs.get('trainable', True):
            # TODO
            collections = kwargs['collections']
            if not collections:
                collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
            else:
                collections = set(collections.copy())
            collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
            collections.add(tf.GraphKeys.LOCAL_VARIABLES)
            kwargs['collections'] = list(collections)
            return getter(*args, **kwargs)

        if size < self.small_variable_size_threshold:
            device_name = self.device_for_small_variables
        else:
...
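OverrideCachingDevice is hooked in through the custom_getter argument of tf.variable_scope (see the with block in DistributedParameterServerBuilder.build above). A minimal sketch of that getter protocol, using a trivial getter that only logs the request before delegating; the scope and variable names here are purely illustrative:

    import tensorflow as tf

    def logging_getter(getter, *args, **kwargs):
        # Same calling convention as OverrideCachingDevice.__call__:
        # inspect or modify kwargs, then delegate to the real getter.
        print('creating variable with shape', kwargs.get('shape'))
        return getter(*args, **kwargs)

    with tf.variable_scope('demo', custom_getter=logging_getter):
        w = tf.get_variable('w', shape=[64, 64])
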
@@ -21,15 +21,19 @@ def get_distributed_session_creator(server):

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    ready_op = tf.report_uninitialized_variables()
    ready_for_local_init_op = tf.report_uninitialized_variables(tf.global_variables())
    sm = tf.train.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op,
        ready_for_local_init_op=ready_for_local_init_op,
        graph=tf.get_default_graph())

    # to debug wrong variable collection
    # from pprint import pprint
    # print("GLOBAL:")
    # pprint([(k.name, k.device) for k in tf.global_variables()])
    # print("LOCAL:")
    # pprint([(k.name, k.device) for k in tf.local_variables()])

    class _Creator(tf.train.SessionCreator):
        def create_session(self):
...
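The ready_for_local_init_op added here lets non-chief workers postpone their local_init_op until the chief has initialized all global variables on the parameter servers. A rough sketch of how a SessionManager configured this way is typically driven; is_chief, server and the *_op tensors are assumed to be in scope as in the function above:

    # Sketch: typical chief / non-chief split when driving the SessionManager.
    if is_chief:
        # The chief runs init_op, then local_init_op.
        sess = sm.prepare_session(server.target, init_op=init_op)
    else:
        # Non-chief workers block until ready_for_local_init_op reports nothing
        # uninitialized, run local_init_op, then wait until ready_op passes.
        sess = sm.wait_for_session(server.target)
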
@@ -231,6 +231,8 @@ def add_moving_summary(*args, **kwargs):

    ema_ops = []
    for c in v:
        name = re.sub('tower[0-9]+/', '', c.op.name)
        # TODO colocate may affect distributed setting
        # colocate variable with compute op implies that the variable should be local_vars
        with G.colocate_with(c), tf.name_scope(None):
            if not c.dtype.is_floating:
                c = tf.cast(c, tf.float32)
...
@@ -20,7 +20,7 @@ from ..graph_builder.training import (
    SyncMultiGPUParameterServerBuilder,
    SyncMultiGPUReplicatedBuilder,
    AsyncMultiGPUBuilder)
from ..graph_builder.distributed import DistributedReplicatedBuilder, DistributedParameterServerBuilder
from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer

@@ -31,6 +31,7 @@ __all__ = ['SimpleTrainer',
           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
           'DistributedTrainerParameterServer',
           'DistributedTrainerReplicated',
           'HorovodTrainer']

@@ -156,6 +157,58 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
        return [cb]

class DistributedTrainerParameterServer(SingleCostTrainer):

    __doc__ = DistributedParameterServerBuilder.__doc__

    devices = None
    """
    List of GPU ids.
    """

    @map_arg(gpus=_int_to_range)
    def __init__(self, gpus, server, caching_device='cpu'):
        """
        Args:
            gpus ([int]): list of GPU ids.
        """
        self.devices = gpus
        self.server = server
        self.job_name = server.server_def.job_name
        assert self.job_name in ['ps', 'worker'], self.job_name

        if self.job_name == 'worker':
            # ps doesn't build any graph
            self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
            self.is_chief = self._builder.is_chief
        else:
            self.is_chief = False
        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

        super(DistributedTrainerParameterServer, self).__init__()

        if self.job_name == 'ps':
            # ps shouldn't setup input either
            logger.info("Running ps {}".format(self.server.server_def.task_index))
            logger.info("Kill me with 'kill {}'".format(os.getpid()))
            self.server.join()  # this function will never return (tensorflow#4713)
            raise RuntimeError("This is a bug. Server.join() for ps should never return!")

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        self.train_op = self._builder.build(
            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
        return []

    @HIDE_DOC
    def initialize(self, session_creator, session_init):
        if not isinstance(session_creator, NewSessionCreator) or \
                session_creator.user_provided_config:
            raise ValueError(
                "You are not allowed to set session_creator or session_config for distributed training! "
                "To use a custom session config, pass it to tf.train.Server.")
        super(DistributedTrainerParameterServer, self).initialize(
            get_distributed_session_creator(self.server), session_init)
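A rough end-to-end launch sketch for the new trainer, assuming tensorpack's usual launch_train_with_config entry point; the cluster addresses, GPU ids and `config` (a TrainConfig) below are placeholders:

    # Sketch: launching the PS trainer from the first worker task.
    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    server = tf.train.Server(cluster, job_name='worker', task_index=0)
    trainer = DistributedTrainerParameterServer(
        gpus=[0, 1], server=server, caching_device='cpu')
    launch_train_with_config(config, trainer)
    # A 'ps' task would build its server with job_name='ps' instead; the
    # trainer constructor then calls server.join() and never returns.
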
class DistributedTrainerReplicated(SingleCostTrainer):

    __doc__ = DistributedReplicatedBuilder.__doc__
...