Commit f0573ed2 authored by Yuxin Wu

Merge branch 'distributed' (#144)

parents a3674b47 930481f2
@@ -36,6 +36,8 @@ class Callback(object):
     .. automethod:: _after_train
     """
+
+    _chief_only = True
 
     def setup_graph(self, trainer):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer
@@ -162,6 +164,19 @@ class Callback(object):
     def local_step(self):
         return self.trainer.local_step
 
+    @property
+    def chief_only(self):
+        """
+        Only run this callback on the chief training process.
+
+        Returns: bool
+        """
+        return self._chief_only
+
+    @chief_only.setter
+    def chief_only(self, v):
+        self._chief_only = v
+
     def __str__(self):
         return type(self).__name__
...
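A hedged usage sketch of the new `chief_only` flag (the `Callback` base class and import path are taken from this diff; the callback name is illustrative): a callback can opt out of the chief-only default either through the class attribute or the new property setter.

    # Sketch only: a callback that must run on every worker, not just the chief.
    from tensorpack.callbacks import Callback

    class SyncCounters(Callback):
        _chief_only = False          # class-level default, like RunUpdateOps below

        def _trigger(self):
            pass                     # per-worker work goes here

    # or flip the flag on an instance through the new setter:
    cb = SyncCounters()
    cb.chief_only = False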
@@ -17,13 +17,15 @@ class RunOp(Callback):
     """ Run an Op. """
 
     def __init__(self, setup_func,
-                 run_before=True, run_as_trigger=True, run_step=False):
+                 run_before=True, run_as_trigger=True,
+                 run_step=False, verbose=False):
         """
         Args:
             setup_func: a function that returns the Op in the graph
             run_before (bool): run the Op before training
             run_as_trigger (bool): run the Op on every trigger
             run_step (bool): run the Op every step (along with training)
+            verbose (bool): print logs when the op is run.
 
         Examples:
             The `DQN Example
@@ -34,27 +36,38 @@ class RunOp(Callback):
         self.run_before = run_before
         self.run_as_trigger = run_as_trigger
         self.run_step = run_step
+        self.verbose = verbose
 
     def _setup_graph(self):
         self._op = self.setup_func()
 
     def _before_train(self):
         if self.run_before:
+            self._print()
             self._op.run()
 
     def _trigger(self):
         if self.run_as_trigger:
+            self._print()
             self._op.run()
 
     def _before_run(self, _):
         if self.run_step:
+            self._print()
             return [self._op]
 
+    def _print(self):
+        if self.verbose:
+            logger.info("Running Op {} ...".format(self._op.name))
+
 
 class RunUpdateOps(RunOp):
     """
     Run ops from the collection UPDATE_OPS every step
     """
+    _chief_only = False
+
     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
         def f():
             ops = tf.get_collection(collection)
...
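For illustration, a minimal way to use the new `verbose` flag (the op-returning function here is a stand-in, not part of tensorpack; the `RunOp` signature follows this diff):

    import tensorflow as tf
    from tensorpack.callbacks import RunOp

    def make_noop():
        # stand-in for a real setup function that returns an op to run
        return tf.no_op(name='example_op')

    cb = RunOp(make_noop, run_before=True, run_as_trigger=True, verbose=True)
    cb.chief_only = False   # e.g. sync ops must run on every worker, not just the chief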
@@ -90,6 +90,7 @@ class InferenceRunnerBase(Callback):
         def fn(_):
             in_tensors = self._input_source.get_input_tensors()
             self.trainer.model.build_graph(in_tensors)
 
+        with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
             PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
         self._hooks = [self._build_hook(inf) for inf in self.infs]
...
@@ -72,7 +72,7 @@ class GraphVarParam(HyperParam):
                 self.var = v
                 break
         else:
-            raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
+            raise ValueError("{} is not a GLOBAL_VARIABLE in the graph!".format(self.var_name))
 
     def set_value(self, v):
         """ Assign the variable a new value. """
...
@@ -43,6 +43,7 @@ class ModelSaver(Callback):
         vars = []
         for key in self.var_collections:
             vars.extend(tf.get_collection(key))
+        vars = list(set(vars))
         self.path = os.path.join(self.checkpoint_dir, 'model')
         if get_tf_version_number() <= 1.1:
             self.saver = tf.train.Saver(
...
@@ -55,13 +55,14 @@ class MaintainStepCounter(Callback):
         # ensure it exists
         gs_var = get_global_step_var()
         with tf.name_scope(None):
-            self.gs_incr_var = tf.assign_add(
-                gs_var, 1,
-                name=GLOBAL_STEP_INCR_OP_NAME)
+            with tf.device(gs_var.device):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
             # tf.mod(
             #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
             #     name=LOCAL_STEP_OP_NAME)
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
 
     def _before_train(self):
         gs_val = get_global_step_value()
@@ -81,6 +82,8 @@ class MaintainStepCounter(Callback):
 
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """
+    _chief_only = False
+
     def __init__(self, names=[]):
         """
         Args:
...
@@ -136,7 +136,7 @@ def layer_register(
                 # log shape info and add activation
                 logger.info("{} output: {}".format(
                     scope.name, get_shape_str(outputs)))
-                _LAYER_LOGGED.add(scope.name)
+                _LAYER_LOGGED.add(scope_name)
             else:
                 # run the actual function
                 outputs = func(*args, **actual_args)
...
@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
+        if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
             continue
         if re.search(regex, para_name):
             costs.append(func(p))
...
@@ -39,9 +39,11 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.inter_op_parallelism_threads = 0
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
+    if get_tf_version_number() >= 1.2:
+        conf.gpu_options.force_gpu_compatible = True
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
-    # force gpu compatible?
     conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
     return conf
...
@@ -154,11 +154,13 @@ def add_moving_summary(v, *args, **kwargs):
     for x in v:
         assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-    # TODO will produce tower0/xxx?
+    # TODO will produce variable tower0/xxx?
+    # TODO not saved under distributed
     # TODO use zero_debias
-    with tf.name_scope(None):
+    gs = get_global_step_var()
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
-            decay, num_updates=get_global_step_var(), name='EMA')
+            decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
     for c in v:
...
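The pattern above, colocating the EMA update with the global step so the averaged variables live on the same device (the parameter server in distributed mode), can be sketched on its own. Names here are illustrative, not tensorpack API; only `get_global_step_var` is taken from this diff.

    import tensorflow as tf
    from tensorpack.tfutils.common import get_global_step_var

    def ema_colocated_with_global_step(scalars, decay=0.95):
        gs = get_global_step_var()
        # clear the name scope and pin the EMA state to the global step's device
        with tf.name_scope(None), tf.device(gs.device):
            averager = tf.train.ExponentialMovingAverage(
                decay, num_updates=gs, name='EMA')
            return averager.apply(scalars)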
@@ -17,13 +17,16 @@ class TowerContext(object):
     def __init__(self, tower_name,
                  device=None, is_training=None,
-                 var_strategy='shared'):
+                 var_strategy='shared',
+                 vs_name=None):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
             device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
             var_strategy (str): either 'shared' or 'replicated'.
+            vs_name (str): the variable scope name to open. Only valid in
+                'replicated' mode. Defaults to tower_name.
         """
         self._name = tower_name
         if device is None:
@@ -38,6 +41,13 @@ class TowerContext(object):
         self._var_strategy = var_strategy
         if self._var_strategy == 'replicated':
             assert self._name
+            if vs_name is None:
+                self._vs_name = self._name
+            else:
+                self._vs_name = vs_name
+        else:
+            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
+            self._vs_name = ''
 
     @property
     def is_main_training_tower(self):
@@ -62,12 +72,7 @@ class TowerContext(object):
     # variable_scope name
     @property
     def vs_name(self):
-        if self.has_own_variables:
-            # do not open new variable scope for the main tower,
-            # just use '', so that Saver & PredictTower know what to do
-            if self.index > 0:
-                return self._name
-        return ""
+        return self._vs_name
 
     @property
     def index(self):
@@ -113,13 +118,16 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                if self.vs_name:
+                if len(self.vs_name):
                     self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
-                # use existing variable scope
-                reuse = self.index > 0 or (not self.is_training)
-                self._ctxs.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=reuse))
+                if self.is_training:
+                    reuse = self.index > 0
+                    if reuse is True:
+                        self._ctxs.append(tf.name_scope(None))
+                        self._ctxs.append(tf.variable_scope(
+                            tf.get_variable_scope(), reuse=True))
+                # if not training, should handle vs outside (TODO not good)
             self._ctxs.append(tf.name_scope(self._name))
         self._ctxs.append(tf.device(self._device))
         for c in self._ctxs:
...
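A small sketch of the new `vs_name` argument (import path and signature taken from the diff above; the variable name is illustrative): a replicated-mode tower opened under an explicit variable scope.

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('tower1', is_training=True,
                      var_strategy='replicated', vs_name='tower1'):
        w = tf.get_variable('w', shape=[10])   # created as 'tower1/w'
    # vs_name=None falls back to the tower name; passing it in 'shared' mode asserts.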
@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
         new_path = model_path.split('.index')[0]
         if new_path != model_path:
             logger.warn(
-                "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+                "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
             model_path = new_path
     assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
     return model_path
@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):
 def is_training_name(name):
     """
-    This is a hack temporarily used to improve logging. Do not use this function.
+    Guess if a name belongs to a training-only variable.
+    Only used internally to avoid too much logging. Do not use it.
 
     Returns:
         bool: Guess whether this tensor is something only used in training.
...
@@ -9,8 +9,6 @@ import six
 from six.moves import range
 import tensorflow as tf
 
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession
 from .predict import PredictorFactory
 from .config import TrainConfig
@@ -21,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.sessinit import JustCurrentSession
 
 __all__ = ['Trainer', 'StopTraining']
@@ -46,6 +45,9 @@ class Trainer(object):
         local_step (int): the number of steps that have finished in the current epoch.
         global_step (int): the number of steps that have finished.
     """
+    # step attr only available after before_train?
+    is_chief = True
 
     def __init__(self, config):
         """
@@ -79,12 +81,18 @@ class Trainer(object):
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
-        self._callbacks.append(cb)
+        if not self.is_chief and cb.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+        else:
+            self._callbacks.append(cb)
 
     def register_monitor(self, mon):
         assert isinstance(mon, TrainingMonitor), mon
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
-        self.monitors.append(mon)
+        if not self.is_chief and mon.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
+        else:
+            self.monitors.append(mon)
         self.register_callback(mon)
@@ -110,6 +118,7 @@ class Trainer(object):
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)
 
+        # TODO cache per graph, avoid describing all towers
         describe_model()
 
         # some final operations that might modify the graph
@@ -117,21 +126,28 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))
 
-        # create session
         logger.info("Creating the session ...")
-        self.sess = self.config.session_creator.create_session()
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
+        self._create_session()
 
-        logger.info("Initializing the session ...")
-        # init session
-        self.config.session_init.init(self.sess)
+        if self.is_chief:
+            logger.info("Initializing the session ...")
+            self.config.session_init.init(self.sess)
+        else:
+            assert isinstance(self.config.session_init, JustCurrentSession), \
+                "session_init is only valid for chief worker session!"
 
         self.sess.graph.finalize()
         logger.info("Graph Finalized.")
 
+    def _create_session(self):
+        """
+        Setup self.sess (the raw tf.Session)
+        and self.hooked_sess (the session with hooks and coordinator)
+        """
         hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+        self.sess = self.config.session_creator.create_session()
+        self.hooked_sess = tf.train.MonitoredSession(
+            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
 
     @abstractmethod
     def _setup(self):
@@ -154,12 +170,14 @@ class Trainer(object):
         self._starting_step = get_global_step_value()
         try:
             self._callbacks.before_train()
+            # refresh global step (might have changed by callbacks) TODO ugly
+            self._starting_step = get_global_step_value()
             for self.epoch_num in range(
                     self.config.starting_epoch, self.config.max_epoch + 1):
                 logger.info("Start Epoch {} ...".format(self.epoch_num))
                 start_time = time.time()
                 for self.local_step in range(self.config.steps_per_epoch):
-                    if self._monitored_sess.should_stop():
+                    if self.hooked_sess.should_stop():
                         return
                     self.run_step()  # implemented by subclass
                     self._callbacks.trigger_step()
@@ -169,6 +187,7 @@ class Trainer(object):
                 # trigger epoch outside the timing region.
                 self._trigger_epoch()
                 self._callbacks.trigger_epoch()
+            logger.info("Training has finished!")
         except (StopTraining, tf.errors.OutOfRangeError):
             logger.info("Training was stopped.")
         except KeyboardInterrupt:
@@ -177,7 +196,14 @@ class Trainer(object):
             raise
         finally:
             self._callbacks.after_train()
-            self._monitored_sess.close()
+            self.hooked_sess.close()
 
+    @property
+    def vs_name_for_predictor(self):
+        """
+        The variable scope name a predictor should be built in.
+        """
+        return ""
+
     # Predictor related methods: TODO
     def get_predictor(self, input_names, output_names, tower=0):
...
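The `_create_session` split above boils down to one pattern: keep a raw `tf.Session` for initialization, and wrap that same session in a `tf.train.MonitoredSession` so the callbacks' hooks see every training `run()`. A standalone sketch under that assumption (the hooks list is left empty for brevity; `ReuseSessionCreator` is the helper imported in this diff):

    import tensorflow as tf
    from tensorpack.tfutils.sesscreate import ReuseSessionCreator

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())   # plays the role of session_init on the raw session
    hooks = []                                     # callbacks' SessionRunHook instances would go here
    hooked_sess = tf.train.MonitoredSession(
        session_creator=ReuseSessionCreator(sess), hooks=hooks)
    # training steps then go through hooked_sess.run(...), which fires the hooks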
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py
import tensorflow as tf
import re
import os
from six.moves import range
from ..utils import logger
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
__all__ = ['DistributedReplicatedTrainer']
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in the LOCAL_VARIABLES and not the GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"""
Distributed replicated training.
Each worker process builds the same model on one or more GPUs.
Gradients across GPUs are averaged within the worker,
and get synchronously applied to the global copy of variables located on PS.
Then each worker copies the latest variables from PS back to local.
Note:
Gradients are not averaged across workers.
"""
def __init__(self, config, server):
"""
Args:
config (TrainConfig): the train config.
server (tf.train.Server): the server object with ps and workers
"""
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.job_name = server_def.job_name
self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], self.job_name
assert tf.test.is_gpu_available()
self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
super(DistributedReplicatedTrainer, self).__init__(config)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
self.num_ps = self.cluster.num_tasks('ps')
self.num_worker = self.cluster.num_tasks('worker')
self.nr_gpu = config.nr_tower
self.cpu_device = '%s/cpu:0' % worker_prefix
self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
# Device for queues for managing synchronization between servers
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
@staticmethod
def _average_grads(tower_grads, devices):
"""
Average grad with round-robin device selection.
Args:
tower_grads: Ngpu x Nvar x 2
Returns:
Nvar x 2
"""
nr_device = len(devices)
if nr_device == 1:
return tower_grads[0]
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for i, grad_and_vars in enumerate(zip(*tower_grads)):
# Ngpu * 2
with tf.device(devices[i % nr_device]):
v = grad_and_vars[0][1]
# average gradient
all_grads = [g for (g, _) in grad_and_vars]
if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
continue
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_device)
new_tower_grads.append((grad, v))
return new_tower_grads
@staticmethod
def _apply_shadow_vars(avg_grads):
"""
Replace variables in avg_grads by shadow variables.
"""
ps_var_grads = []
for grad, var in avg_grads:
assert var.name.startswith('tower'), var.name
my_name = '/'.join(var.name.split('/')[1:])
my_name = get_op_tensor_name(my_name)[0]
new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
initializer=var.initial_value,
trainable=True)
# (g, v) to be applied, where v is global (ps vars)
ps_var_grads.append((grad, new_v))
return ps_var_grads
@staticmethod
def _shadow_model_variables(shadow_vars):
"""
Create shadow vars for model_variables as well, and add to the list of ``shadow_vars``.
Returns:
list of (shadow_model_var, local_model_var) used for syncing.
"""
curr_shadow_vars = set([v.name for v in shadow_vars])
model_vars = tf.model_variables()
shadow_model_vars = []
for v in model_vars:
assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
if stripped_name in curr_shadow_vars:
continue
new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
initializer=v.initial_value,
trainable=False)
curr_shadow_vars.add(stripped_name) # avoid duplicated shadow_model_vars
shadow_vars.append(new_v)
shadow_model_vars.append((new_v, v)) # only need to sync model_var from one tower
return shadow_model_vars
def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
"""
Args:
raw_grad_list: Ngpu x Nvar x 2 gradient list from all towers
ps_var_grads: Nvar x 2 (grad, ps_var)
Returns:
list of copy ops
"""
# TODO do this for variables together?
opt = self.model.get_optimizer()
var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads):
apply_gradient_op = opt.apply_gradients([(g, v)])
barrier = self._add_sync_queues_and_barrier(
'param_update_barrier_{}'.format(vid), [apply_gradient_op])
with tf.control_dependencies([barrier]), \
tf.device(self.cpu_device):
updated_value = v.read_value()
for towerid in range(self.nr_gpu):
var_update_ops.append(
raw_grad_list[towerid][vid][1].assign(updated_value))
return var_update_ops
def _setup(self):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return #4713
return
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
# do this before super's _setup because input_source may need the global step
super(DistributedReplicatedTrainer, self)._setup()
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
devices=self.raw_devices,
var_strategy='replicated',
vs_names=None) # use the default vs names
avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
self._shadow_vars = [v for (_, v) in ps_var_grads]
self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)
# TODO add options to synchronize less
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self._add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
# initial local_vars syncing
cb = RunOp(self._get_initial_sync_op,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
self.register_callback(cb)
# model_variables syncing
if len(self._shadow_model_vars) and self.is_chief:
cb = RunOp(self._get_sync_model_vars_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequenctly.")
self.register_callback(cb)
self._set_session_creator()
def _set_session_creator(self):
old_sess_creator = self.config.session_creator
if not isinstance(old_sess_creator, NewSessionCreator) \
or self.config.session_config is not None:
raise ValueError(
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it with tf.train.Server.")
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
ready_op = tf.report_uninitialized_variables()
sm = tf.train.SessionManager(
local_init_op=local_init_op,
ready_op=ready_op, graph=tf.get_default_graph())
def _create_session():
if self.is_chief:
return sm.prepare_session(master=self.server.target, init_op=init_op)
else:
return sm.wait_for_session(master=self.server.target)
class _Creator(tf.train.SessionCreator):
def create_session(self):
return _create_session()
self.config.session_creator = _Creator()
def _add_sync_queues_and_barrier(self, name, dependencies):
"""Adds ops to enqueue on all worker queues.
Args:
name: prefix for the shared_name of ops.
dependencies: control dependency from ops.
Returns:
an op that should be used as control dependency before starting next step.
"""
self.sync_queue_counter += 1
with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
sync_queues = [
tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
shared_name='%s%s' % (name, i))
for i in range(self.num_worker)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can finish this step.
token = tf.constant(False)
with tf.control_dependencies(dependencies):
for i, q in enumerate(sync_queues):
if i != self.task_index:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
queue_ops.append(
sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))
return tf.group(*queue_ops, name=name)
def _get_initial_sync_op(self):
"""
Get the op to copy-initialize all local variables from PS.
"""
def strip_port(s):
if s.endswith(':0'):
return s[:-2]
return s
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
ops = []
nr_shadow_vars = len(self._shadow_vars)
for v in self._shadow_vars:
vname = strip_port(v.name)
for i in range(self.nr_gpu):
name = 'tower%s/%s' % (i, vname)
assert name in local_var_by_name, \
"Shadow variable {} doesn't match a corresponding local variable!".format(v.name)
copy_to = local_var_by_name[name]
# logger.info("{} -> {}".format(v.name, copy_to.name))
ops.append(copy_to.assign(v.read_value()))
return tf.group(*ops, name='sync_{}_variables_from_ps'.format(nr_shadow_vars))
def _get_sync_model_vars_op(self):
"""
Get the op to sync local model_variables to PS.
"""
ops = []
for (shadow_v, local_v) in self._shadow_model_vars:
ops.append(shadow_v.assign(local_v.read_value()))
assert len(ops)
return tf.group(*ops, name='sync_{}_model_variables_to_ps'.format(len(ops)))
@property
def vs_name_for_predictor(self):
return "tower0"
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
 import tensorflow as tf
+from six.moves import zip
 
 from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_source import QueueInput, FeedfreeInput
@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
-        # opt may be created under first-tower variable scope (which is '')
-        opt = self.model.get_optimizer()
-        # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            # TODO assumption on the first-tower empty variable scope
+            # TODO use ctx.vars?
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
-        grads = opt.compute_gradients(
-            cost,
-            var_list=varlist,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
-            colocate_gradients_with_ops=True)
+        grads = tf.gradients(
+            cost,
+            varlist,
+            gate_gradients=False,
+            colocate_gradients_with_ops=True)
+        grads = list(zip(grads, varlist))
         return cost, grads
...
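The switch from `Optimizer.compute_gradients` to `tf.gradients` above amounts to this small helper (names are illustrative; only the TensorFlow calls shown in the diff are used):

    import tensorflow as tf

    def cost_and_grads(cost, varlist):
        # tf.gradients returns a list of gradient tensors aligned with varlist
        grads = tf.gradients(
            cost, varlist,
            gate_gradients=False,
            colocate_gradients_with_ops=True)
        # pair them back up, matching what Optimizer.compute_gradients used to return
        return cost, list(zip(grads, varlist))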
@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
     def setup_training(self, trainer):
         super(QueueInput, self).setup_training(trainer)
-        trainer.register_callback(StartProcOrThread(self.thread))
+        cb = StartProcOrThread(self.thread)
+        cb.chief_only = False
+        trainer.register_callback(cb)
 
     def get_input_tensors(self):
         with tf.device('/cpu:0'):
@@ -365,6 +367,7 @@ class DummyConstantInput(TensorInput):
         def fn():
             tlist = []
             ctx = get_current_tower_context()
+            assert ctx is not None
             assert len(self.shapes) == len(self.input_placehdrs)
             for idx, p in enumerate(self.input_placehdrs):
                 tlist.append(tf.get_variable(
...
@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
 class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
+    def build_on_multi_tower(
+            towers, func,
+            devices=None, var_strategy='shared',
+            vs_names=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
-            var_strategy (str):
+            var_strategy (str): 'shared' or 'replicated'
+            vs_names (list[str]): list of variable scope names to use.
 
         Returns:
             List of outputs of ``func``, evaluated on each tower.
@@ -70,15 +74,20 @@ class MultiGPUTrainerBase(Trainer):
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':    # TODO ugly
-            logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
+            logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+        else:
+            assert vs_names is None
+        if vs_names is None:
+            vs_names = [None] * len(towers)
 
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with TowerContext(
                     'tower{}'.format(idx),
                     device=device, is_training=True,
-                    var_strategy=var_strategy):
+                    var_strategy=var_strategy,
+                    vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
-            var_strategy='replicated')
+            var_strategy='replicated',
+            # use no variable scope for the first tower
+            vs_names=[''] + [None] * (self.config.nr_tower - 1))
         grads = self._allreduce_grads(grad_list)
 
         train_ops = []
@@ -261,7 +272,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
         self.train_op = tf.group(*train_ops, name='train_op')
         self.register_callback(RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
-            run_before=True, run_as_trigger=True))
+            run_before=True, run_as_trigger=True, verbose=True))
 
 # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
@@ -279,7 +290,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
             split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
-        return tf.group(*post_init_ops, name='init_sync_vars')
+        return tf.group(*post_init_ops, name='sync_variables_from_tower0')
 
 class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
...
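To illustrate the new `vs_names` argument (a sketch only; `tower_fn` is a placeholder for the per-tower build function and the module path follows this commit's layout): the single-machine replicated trainer keeps the first tower in the root variable scope, while the distributed trainer passes `None` so every tower opens its own `towerN` scope.

    from tensorpack.train.multigpu import MultiGPUTrainerBase

    def build_replicated(towers, tower_fn, first_tower_in_root_scope=True):
        # None entries fall back to the default tower names inside TowerContext
        vs_names = ([''] + [None] * (len(towers) - 1)) if first_tower_in_root_scope else None
        return MultiGPUTrainerBase.build_on_multi_tower(
            towers, tower_fn,
            var_strategy='replicated',
            vs_names=vs_names)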
@@ -3,6 +3,7 @@
 # File: predict.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
+import tensorflow as tf
 from ..predict import (OnlinePredictor,
                        PredictorTowerBuilder)
@@ -19,6 +20,7 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+        self.vs_name = trainer.vs_name_for_predictor
 
         def fn(_):
             self.model.build_graph(self.model.get_reused_placehdrs())
@@ -34,6 +36,7 @@ class PredictorFactory(object):
         """
         tower = self.towers[tower]
         # just ensure the tower exists. won't rebuild (memoized)
+        with tf.variable_scope(self.vs_name, reuse=True):
             self._tower_builder.build(tower)
 
         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
...