Commit 2d6d7ad4 authored by Yuxin Wu

Add Distributed PS Trainer. Missing some features such as saving model vars (#493)

parent 3145729a
......@@ -9,13 +9,106 @@ from six.moves import zip, range
from ..utils.argtools import memoized
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .training import DataParallelBuilder
from .utils import override_to_local_variable
from .training import GraphBuilder, DataParallelBuilder
from .utils import (
override_to_local_variable, average_grads,
OverrideCachingDevice)
__all__ = ['DistributedReplicatedBuilder']
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
class DistributedReplicatedBuilder(DataParallelBuilder):
class DistributedBuilderBase(GraphBuilder):
_sync_queue_counter = 0
def __init__(self, server):
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.task_index = server_def.task_index
self.num_ps = self.cluster.num_tasks('ps')
self.num_worker = self.cluster.num_tasks('worker')
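# Illustration only (not part of this commit): with a hypothetical 2-ps / 2-worker
# cluster built from the standard tf.train APIs,
#
#   cluster_spec = tf.train.ClusterSpec({
#       'ps': ['host0:2222', 'host1:2222'],
#       'worker': ['host2:2222', 'host3:2222']})
#   server = tf.train.Server(cluster_spec, job_name='worker', task_index=0)
#   builder_base = DistributedBuilderBase(server)
#
# the attributes above evaluate to num_ps == 2, num_worker == 2 and task_index == 0.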
def _add_sync_queues_and_barrier(self, name, dependencies):
"""Adds ops to enqueue on all worker queues.
Args:
name: prefix for the shared_name of the ops.
dependencies: ops to use as control dependencies.
Returns:
an op that should be used as a control dependency before starting the next step.
"""
self._sync_queue_counter += 1
with tf.device(self.sync_queue_devices[self._sync_queue_counter % len(self.sync_queue_devices)]):
sync_queues = [
tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
shared_name='%s%s' % (name, i))
for i in range(self.num_worker)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can finish this step.
token = tf.constant(False)
with tf.control_dependencies(dependencies):
for i, q in enumerate(sync_queues):
if i != self.task_index:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
queue_ops.append(
sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))
return tf.group(*queue_ops, name=name)
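# Sketch of the intended call pattern (an assumption drawn from the subclasses below):
# a worker wraps its final train_op in this barrier so that no worker starts the
# next step before every worker has enqueued its token, e.g.
#
#   train_op = opt.apply_gradients(grads, name='train_op')
#   train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])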
class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderBase):
def __init__(self, towers, server, caching_device):
DataParallelBuilder.__init__(self, towers)
DistributedBuilderBase.__init__(self, server)
assert caching_device in ['cpu', 'gpu'], caching_device
self.caching_device = caching_device
self.is_chief = (self.task_index == 0)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
self.cpu_device = '%s/cpu:0' % worker_prefix
self.raw_devices = ['{}/gpu:{}'.format(worker_prefix, k) for k in self.towers]
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
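# For illustration (assuming task_index == 1 and towers == [0, 1]), the device
# strings computed above are:
#   self.cpu_device  == '/job:worker/task:1/cpu:0'
#   self.raw_devices == ['/job:worker/task:1/gpu:0', '/job:worker/task:1/gpu:1']
#   self.sync_queue_devices == ['/job:ps/task:0/cpu:0', ...], one per ps task.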
def build(self, get_grad_fn, get_opt_fn):
ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(
self.num_ps, tf.contrib.training.byte_size_load_fn)
devices = [
tf.train.replica_device_setter(
worker_device=d,
cluster=self.cluster,
ps_strategy=ps_strategy) for d in self.raw_devices]
if self.caching_device == 'gpu':
caching_devices = self.raw_devices
else:
caching_devices = [self.cpu_device]
custom_getter = OverrideCachingDevice(
caching_devices, self.cpu_device, 1024 * 64)
with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
DataParallelBuilder._check_grad_list(grad_list)
with tf.device(self.param_server_device):
grads = average_grads(grad_list, colocation=False)
opt = get_opt_fn()
train_op = opt.apply_gradients(grads, name='train_op')
train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])
return train_op
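# How a trainer is expected to drive this builder (a sketch mirroring
# DistributedTrainerParameterServer._setup_graph added later in this commit):
#
#   builder = DistributedParameterServerBuilder(gpus, server, caching_device='cpu')
#   train_op = builder.build(get_grad_fn, get_opt_fn)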
class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
"""
Distributed replicated training.
Each worker process builds the same model on one or more GPUs.
......@@ -62,27 +155,21 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
The job_name must be 'worker' because the 'ps' job doesn't need to
build any graph.
"""
super(DistributedReplicatedBuilder, self).__init__(towers)
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.task_index = server_def.task_index
DataParallelBuilder.__init__(self, towers)
DistributedBuilderBase.__init__(self, server)
self.is_chief = (self.task_index == 0)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
self.num_ps = self.cluster.num_tasks('ps')
self.num_worker = self.cluster.num_tasks('worker')
self.nr_gpu = len(self.towers)
self.cpu_device = '%s/cpu:0' % worker_prefix
self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
self.raw_devices = ['%s/gpu:%i' % (worker_prefix, i) for i in towers]
# Devices for the queues that manage synchronization between servers
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
@staticmethod
def _average_grads(tower_grads, devices):
......@@ -156,36 +243,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
shadow_model_vars.append((new_v, v)) # only need to sync model_var from one tower
return shadow_model_vars
def _add_sync_queues_and_barrier(self, name, dependencies):
"""Adds ops to enqueue on all worker queues.
Args:
name: prefix for the shared_name of the ops.
dependencies: ops to use as control dependencies.
Returns:
an op that should be used as a control dependency before starting the next step.
"""
self.sync_queue_counter += 1
with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
sync_queues = [
tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
shared_name='%s%s' % (name, i))
for i in range(self.num_worker)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can finish this step.
token = tf.constant(False)
with tf.control_dependencies(dependencies):
for i, q in enumerate(sync_queues):
if i != self.task_index:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
queue_ops.append(
sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))
return tf.group(*queue_ops, name=name)
def build(self, get_grad_fn, get_opt_fn):
"""
Args:
......
......@@ -169,11 +169,18 @@ class OverrideCachingDevice(object):
def __call__(self, getter, *args, **kwargs):
size = tf.TensorShape(kwargs['shape']).num_elements()
if size is None:
# print(args, kwargs)
return getter(*args, **kwargs)
if not kwargs.get('trainable', True):
if size is None or not kwargs.get('trainable', True):
# TODO
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.discard(tf.GraphKeys.GLOBAL_VARIABLES)  # discard(): GLOBAL_VARIABLES may be absent
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(*args, **kwargs)
if size < self.small_variable_size_threshold:
device_name = self.device_for_small_variables
else:
......
......@@ -21,15 +21,19 @@ def get_distributed_session_creator(server):
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
ready_op = tf.report_uninitialized_variables()
ready_for_local_init_op = tf.report_uninitialized_variables(tf.global_variables())
sm = tf.train.SessionManager(
local_init_op=local_init_op,
ready_op=ready_op, graph=tf.get_default_graph())
ready_op=ready_op,
ready_for_local_init_op=ready_for_local_init_op,
graph=tf.get_default_graph())
# to debug wrong variable collection
# from pprint import pprint
# print("GLOBAL:")
# print(tf.global_variables())
# pprint([(k.name, k.device) for k in tf.global_variables()])
# print("LOCAL:")
# print(tf.local_variables())
# pprint([(k.name, k.device) for k in tf.local_variables()])
class _Creator(tf.train.SessionCreator):
def create_session(self):
......
......@@ -231,6 +231,8 @@ def add_moving_summary(*args, **kwargs):
ema_ops = []
for c in v:
name = re.sub('tower[0-9]+/', '', c.op.name)
# TODO colocation may affect the distributed setting
# colocating the variable with its compute op implies that the variable should be a local variable
with G.colocate_with(c), tf.name_scope(None):
if not c.dtype.is_floating:
c = tf.cast(c, tf.float32)
......
......@@ -20,7 +20,7 @@ from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder)
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.distributed import DistributedReplicatedBuilder, DistributedParameterServerBuilder
from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer
......@@ -31,6 +31,7 @@ __all__ = ['SimpleTrainer',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
'DistributedTrainerParameterServer',
'DistributedTrainerReplicated',
'HorovodTrainer']
......@@ -156,6 +157,58 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
return [cb]
class DistributedTrainerParameterServer(SingleCostTrainer):
__doc__ = DistributedParameterServerBuilder.__doc__
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, server, caching_device='cpu'):
"""
Args:
gpus ([int]): list of GPU ids.
"""
self.devices = gpus
self.server = server
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
super(DistributedTrainerParameterServer, self).__init__()
if self.job_name == 'ps':
# ps shouldn't set up input either
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join()  # this call never returns; see tensorflow#4713
raise RuntimeError("This is a bug. Server.join() for ps should never return!")
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
@HIDE_DOC
def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator) or \
session_creator.user_provided_config:
raise ValueError(
"You are not allowed to set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerParameterServer, self).initialize(
get_distributed_session_creator(self.server), session_init)
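# A hedged end-to-end usage sketch (hostnames and ports are hypothetical); each
# process in the cluster runs the same script with its own job_name/task_index:
#
#   cluster_spec = tf.train.ClusterSpec(
#       {'ps': ['host0:2222'], 'worker': ['host1:2222', 'host2:2222']})
#   server = tf.train.Server(cluster_spec, job_name='worker', task_index=0)
#   trainer = DistributedTrainerParameterServer([0, 1], server, caching_device='cpu')
#
# A process launched with job_name='ps' never returns from __init__: it joins the
# server and only serves variables.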
class DistributedTrainerReplicated(SingleCostTrainer):
__doc__ = DistributedReplicatedBuilder.__doc__
......