Commit f0573ed2 authored by Yuxin Wu's avatar Yuxin Wu

Merge branch 'distributed' (#144)

parents a3674b47 930481f2
......@@ -36,6 +36,8 @@ class Callback(object):
.. automethod:: _after_train
"""
_chief_only = True
def setup_graph(self, trainer):
self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer
......@@ -162,6 +164,19 @@ class Callback(object):
def local_step(self):
return self.trainer.local_step
@property
def chief_only(self):
"""
Only run this callback on the chief training process.
Returns: bool
"""
return self._chief_only
@chief_only.setter
def chief_only(self, v):
self._chief_only = v
def __str__(self):
return type(self).__name__
......
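The new `chief_only` attribute can also be flipped per instance through the setter added above. A minimal sketch, assuming the tensorpack `Callback` API shown in this hunk; the `LogWorkerEpoch` callback and its message are made up for illustration:

```python
# Hedged sketch of the new chief_only flag (illustrative callback, not from this commit).
from tensorpack.callbacks import Callback
from tensorpack.utils import logger

class LogWorkerEpoch(Callback):
    """Toy callback that logs a message at every trigger, on every worker."""
    _chief_only = False          # class-level default: run on all workers

    def _trigger(self):
        logger.info("epoch finished on this worker")

cb = LogWorkerEpoch()
cb.chief_only = False            # or flip it per instance via the new setter
```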
......@@ -17,13 +17,15 @@ class RunOp(Callback):
""" Run an Op. """
def __init__(self, setup_func,
run_before=True, run_as_trigger=True, run_step=False):
run_before=True, run_as_trigger=True,
run_step=False, verbose=False):
"""
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_as_trigger (bool): run the Op on every trigger
run_step (bool): run the Op every step (along with training)
verbose (bool): print a log when the op is run.
Examples:
The `DQN Example
......@@ -34,27 +36,38 @@ class RunOp(Callback):
self.run_before = run_before
self.run_as_trigger = run_as_trigger
self.run_step = run_step
self.verbose = verbose
def _setup_graph(self):
self._op = self.setup_func()
def _before_train(self):
if self.run_before:
self._print()
self._op.run()
def _trigger(self):
if self.run_as_trigger:
self._print()
self._op.run()
def _before_run(self, _):
if self.run_step:
self._print()
return [self._op]
def _print(self):
if self.verbose:
logger.info("Running Op {} ...".format(self._op.name))
class RunUpdateOps(RunOp):
"""
Run ops from the collection UPDATE_OPS every step
"""
_chief_only = False
def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
def f():
ops = tf.get_collection(collection)
......
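For reference, a hedged usage sketch of the extended `RunOp` signature; the op factory below is an illustrative stand-in, not part of this commit:

```python
# Illustrative use of RunOp with the new `verbose` flag.
import tensorflow as tf
from tensorpack.callbacks import RunOp

def make_counter_op():
    # hypothetical op factory: returns an op that increments a counter variable
    counter = tf.get_variable('my_counter', shape=[], dtype=tf.int32,
                              initializer=tf.zeros_initializer())
    return tf.assign_add(counter, 1, name='incr_my_counter')

cb = RunOp(make_counter_op,
           run_before=True,       # run once before training starts
           run_as_trigger=True,   # run on every trigger (typically per epoch)
           run_step=False,        # do not run alongside every training step
           verbose=True)          # log "Running Op ..." each time it runs
```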
......@@ -90,7 +90,8 @@ class InferenceRunnerBase(Callback):
def fn(_):
in_tensors = self._input_source.get_input_tensors()
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._hooks = [self._build_hook(inf) for inf in self.infs]
......
......@@ -72,7 +72,7 @@ class GraphVarParam(HyperParam):
self.var = v
break
else:
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
raise ValueError("{} is not a GLOBAL_VARIABLE in the graph!".format(self.var_name))
def set_value(self, v):
""" Assign the variable a new value. """
......
......@@ -43,6 +43,7 @@ class ModelSaver(Callback):
vars = []
for key in self.var_collections:
vars.extend(tf.get_collection(key))
vars = list(set(vars))
self.path = os.path.join(self.checkpoint_dir, 'model')
if get_tf_version_number() <= 1.1:
self.saver = tf.train.Saver(
......
......@@ -55,13 +55,14 @@ class MaintainStepCounter(Callback):
# ensure it exists
gs_var = get_global_step_var()
with tf.name_scope(None):
self.gs_incr_var = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME)
with tf.device(gs_var.device):
self.gs_incr_op = tf.assign_add(
gs_var, 1,
name=GLOBAL_STEP_INCR_OP_NAME).op
# tf.mod(
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
# name=LOCAL_STEP_OP_NAME)
self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
def _before_train(self):
gs_val = get_global_step_value()
......@@ -81,6 +82,8 @@ class MaintainStepCounter(Callback):
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
_chief_only = False
def __init__(self, names=[]):
"""
Args:
......
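The step-counter change above boils down to building the increment op on the device that owns the global-step variable. A minimal standalone sketch, assuming a TF 1.x version that provides `tf.train.get_or_create_global_step`:

```python
# Sketch: colocate the increment op with the global-step variable, so the add
# runs where the variable lives (e.g. on a parameter server in distributed mode).
import tensorflow as tf

gs_var = tf.train.get_or_create_global_step()
with tf.name_scope(None), tf.device(gs_var.device):
    gs_incr_op = tf.assign_add(gs_var, 1, name='global_step_incr').op
```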
......@@ -136,7 +136,7 @@ def layer_register(
# log shape info and add activation
logger.info("{} output: {}".format(
scope.name, get_shape_str(outputs)))
_LAYER_LOGGED.add(scope.name)
_LAYER_LOGGED.add(scope_name)
else:
# run the actual function
outputs = func(*args, **actual_args)
......
......@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
for p in params:
para_name = p.name
# in replicated mode, only regularize variables inside this tower
if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
continue
if re.search(regex, para_name):
costs.append(func(p))
......
......@@ -39,9 +39,11 @@ def get_default_sess_config(mem_fraction=0.99):
conf.inter_op_parallelism_threads = 0
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
if get_tf_version_number() >= 1.2:
conf.gpu_options.force_gpu_compatible = True
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
# force gpu compatible?
conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
return conf
......
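The same GPU options can be set on a standalone `ConfigProto`; a small sketch with illustrative values:

```python
# Illustrative session config mirroring the options toggled above.
import tensorflow as tf

conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = 0.99
conf.gpu_options.allow_growth = True           # allocate GPU memory lazily
conf.graph_options.optimizer_options.global_jit_level = \
    tf.OptimizerOptions.ON_1                   # enable XLA JIT at level 1
sess = tf.Session(config=conf)
```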
......@@ -154,11 +154,13 @@ def add_moving_summary(v, *args, **kwargs):
for x in v:
assert isinstance(x, tf.Tensor), x
assert x.get_shape().ndims == 0, x.get_shape()
# TODO will produce tower0/xxx?
# TODO will produce variable tower0/xxx?
# TODO not saved under distributed
# TODO use zero_debias
with tf.name_scope(None):
gs = get_global_step_var()
with tf.name_scope(None), tf.device(gs.device):
averager = tf.train.ExponentialMovingAverage(
decay, num_updates=get_global_step_var(), name='EMA')
decay, num_updates=gs, name='EMA')
avg_maintain_op = averager.apply(v)
for c in v:
......
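The moving-average change follows the same colocation idea: keep the EMA updates on the device of the global step passed as `num_updates`. A sketch, where `cost` is only an illustrative scalar variable:

```python
# Sketch of the EMA colocation pattern above (illustrative tensor and decay).
import tensorflow as tf

cost = tf.get_variable('cost', shape=[], initializer=tf.zeros_initializer())
gs = tf.train.get_or_create_global_step()
with tf.name_scope(None), tf.device(gs.device):
    averager = tf.train.ExponentialMovingAverage(
        decay=0.95, num_updates=gs, name='EMA')
    avg_maintain_op = averager.apply([cost])
```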
......@@ -17,13 +17,16 @@ class TowerContext(object):
def __init__(self, tower_name,
device=None, is_training=None,
var_strategy='shared'):
var_strategy='shared',
vs_name=None):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
is_training (bool): if None, automatically determine from tower_name.
var_strategy (str): either 'shared' or 'replicated'.
vs_name (str): the variable scope name to open. Only valid in
'replicated' mode. Defaults to ``tower_name``.
"""
self._name = tower_name
if device is None:
......@@ -38,6 +41,13 @@ class TowerContext(object):
self._var_strategy = var_strategy
if self._var_strategy == 'replicated':
assert self._name
if vs_name is None:
self._vs_name = self._name
else:
self._vs_name = vs_name
else:
assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
self._vs_name = ''
@property
def is_main_training_tower(self):
......@@ -62,12 +72,7 @@ class TowerContext(object):
# variable_scope name
@property
def vs_name(self):
if self.has_own_variables:
# do not open new variable scope for the main tower,
# just use '', so that Saver & PredictTower know what to do
if self.index > 0:
return self._name
return ""
return self._vs_name
@property
def index(self):
......@@ -113,13 +118,16 @@ class TowerContext(object):
self._ctxs = []
if len(self._name):
if self.has_own_variables:
if self.vs_name:
if len(self.vs_name):
self._ctxs.append(tf.variable_scope(self.vs_name))
else:
# use existing variable scope
reuse = self.index > 0 or (not self.is_training)
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=reuse))
if self.is_training:
reuse = self.index > 0
if reuse is True:
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True))
# if not training, should handle vs outside (TODO not good)
self._ctxs.append(tf.name_scope(self._name))
self._ctxs.append(tf.device(self._device))
for c in self._ctxs:
......
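A hedged sketch of the new `vs_name` argument, mirroring how `SyncMultiGPUTrainerReplicated` (further down) passes an empty scope for the first tower so its variables get no `tower0/` prefix:

```python
# Illustrative use of TowerContext with the new vs_name argument.
from tensorpack.tfutils.tower import TowerContext

with TowerContext('tower0', is_training=True,
                  var_strategy='replicated', vs_name=''):
    pass  # build the tower graph here; variables are created without a scope prefix
```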
......@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
return model_path
......@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):
def is_training_name(name):
"""
This is a hack temporarily used to improve logging. Do not use this function.
Guess whether a name belongs to a training-only variable.
Only used internally to reduce logging noise. Do not use it.
Returns:
bool: Guess whether this tensor is something only used in training.
......
......@@ -9,8 +9,6 @@ import six
from six.moves import range
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from .predict import PredictorFactory
from .config import TrainConfig
......@@ -21,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
__all__ = ['Trainer', 'StopTraining']
......@@ -46,6 +45,9 @@ class Trainer(object):
local_step (int): the number of steps that have finished in the current epoch.
global_step (int): the number of steps that have finished.
"""
# step attr only available after before_train?
is_chief = True
def __init__(self, config):
"""
......@@ -79,14 +81,20 @@ class Trainer(object):
assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \
"Cannot register more callbacks after trainer was setup!"
self._callbacks.append(cb)
if not self.is_chief and cb.chief_only:
logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
else:
self._callbacks.append(cb)
def register_monitor(self, mon):
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self.monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
self.monitors.append(mon)
self.register_callback(mon)
if not self.is_chief and mon.chief_only:
logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
else:
self.monitors.append(mon)
self.register_callback(mon)
def train(self):
""" Start training """
......@@ -110,6 +118,7 @@ class Trainer(object):
self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors)
# TODO cache per graph, avoid describing all towers
describe_model()
# some final operations that might modify the graph
......@@ -117,21 +126,28 @@ class Trainer(object):
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
# create session
logger.info("Creating the session ...")
self.sess = self.config.session_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None)
self._create_session()
logger.info("Initializing the session ...")
# init session
self.config.session_init.init(self.sess)
if self.is_chief:
logger.info("Initializing the session ...")
self.config.session_init.init(self.sess)
else:
assert isinstance(self.config.session_init, JustCurrentSession), \
"session_init is only valid for chief worker session!"
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
self.sess = self.config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@abstractmethod
def _setup(self):
......@@ -154,12 +170,14 @@ class Trainer(object):
self._starting_step = get_global_step_value()
try:
self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly
self._starting_step = get_global_step_value()
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.local_step in range(self.config.steps_per_epoch):
if self._monitored_sess.should_stop():
if self.hooked_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
......@@ -169,6 +187,7 @@ class Trainer(object):
# trigger epoch outside the timing region.
self._trigger_epoch()
self._callbacks.trigger_epoch()
logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError):
logger.info("Training was stopped.")
except KeyboardInterrupt:
......@@ -177,7 +196,14 @@ class Trainer(object):
raise
finally:
self._callbacks.after_train()
self._monitored_sess.close()
self.hooked_sess.close()
@property
def vs_name_for_predictor(self):
"""
The variable scope name a predictor should be built in.
"""
return ""
# Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py
import tensorflow as tf
import re
import os
from six.moves import range
from ..utils import logger
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
__all__ = ['DistributedReplicatedTrainer']
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in the LOCAL_VARIABLES and not the GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
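This custom getter is attached to a variable scope in `_setup()` below; a small sketch of the effect, using plain TF calls and an illustrative variable:

```python
# Sketch: variables created under a scope with this custom getter land in
# LOCAL_VARIABLES instead of GLOBAL_VARIABLES.
import tensorflow as tf

with tf.variable_scope(tf.get_variable_scope(),
                       custom_getter=OverrideToLocalVariable()):
    w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())

assert w in tf.local_variables()
```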
class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"""
Distributed replicated training.
Each worker process builds the same model on one or more GPUs.
Gradients across GPUs are averaged within each worker,
and applied synchronously to the global copy of variables located on the PS.
Then each worker copies the latest variables from the PS back to its local copies.
Note:
Gradients are not averaged across workers.
"""
def __init__(self, config, server):
"""
Args:
config (TrainConfig): the train config.
server (tf.train.Server): the server object with ps and workers
"""
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.job_name = server_def.job_name
self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], self.job_name
assert tf.test.is_gpu_available
self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
super(DistributedReplicatedTrainer, self).__init__(config)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
self.num_ps = self.cluster.num_tasks('ps')
self.num_worker = self.cluster.num_tasks('worker')
self.nr_gpu = config.nr_tower
self.cpu_device = '%s/cpu:0' % worker_prefix
self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
# Device for queues for managing synchronization between servers
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
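A hypothetical launch sketch for this trainer; the cluster addresses, job name, and `config` are placeholders, not part of this commit:

```python
# Illustrative setup: each process creates a tf.train.Server and hands it to the trainer.
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)
# config = TrainConfig(...)   # a normal tensorpack TrainConfig (placeholder)
# DistributedReplicatedTrainer(config, server).train()
```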
@staticmethod
def _average_grads(tower_grads, devices):
"""
Average gradients with round-robin device placement.
Args:
tower_grads: Ngpu x Nvar x 2
Returns:
Nvar x 2
"""
nr_device = len(devices)
if nr_device == 1:
return tower_grads[0]
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for i, grad_and_vars in enumerate(zip(*tower_grads)):
# Ngpu * 2
with tf.device(devices[i % nr_device]):
v = grad_and_vars[0][1]
# average gradient
all_grads = [g for (g, _) in grad_and_vars]
if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
continue
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_device)
new_tower_grads.append((grad, v))
return new_tower_grads
@staticmethod
def _apply_shadow_vars(avg_grads):
"""
Replace variables in avg_grads by shadow variables.
"""
ps_var_grads = []
for grad, var in avg_grads:
assert var.name.startswith('tower'), var.name
my_name = '/'.join(var.name.split('/')[1:])
my_name = get_op_tensor_name(my_name)[0]
new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
initializer=var.initial_value,
trainable=True)
# (g, v) to be applied, where v is global (ps vars)
ps_var_grads.append((grad, new_v))
return ps_var_grads
@staticmethod
def _shadow_model_variables(shadow_vars):
"""
Create shadow vars for model_variables as well, and add to the list of ``shadow_vars``.
Returns:
list of (shadow_model_var, local_model_var) used for syncing.
"""
curr_shadow_vars = set([v.name for v in shadow_vars])
model_vars = tf.model_variables()
shadow_model_vars = []
for v in model_vars:
assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
if stripped_name in curr_shadow_vars:
continue
new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
initializer=v.initial_value,
trainable=False)
curr_shadow_vars.add(stripped_name) # avoid duplicated shadow_model_vars
shadow_vars.append(new_v)
shadow_model_vars.append((new_v, v)) # only need to sync model_var from one tower
return shadow_model_vars
def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
"""
Args:
raw_grad_list: Ngpu x Nvar x 2 gradient list from all towers
ps_var_grads: Nvar x 2 (grad, ps_var)
Returns:
list of copy ops
"""
# TODO do this for variables together?
opt = self.model.get_optimizer()
var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads):
apply_gradient_op = opt.apply_gradients([(g, v)])
barrier = self._add_sync_queues_and_barrier(
'param_update_barrier_{}'.format(vid), [apply_gradient_op])
with tf.control_dependencies([barrier]), \
tf.device(self.cpu_device):
updated_value = v.read_value()
for towerid in range(self.nr_gpu):
var_update_ops.append(
raw_grad_list[towerid][vid][1].assign(updated_value))
return var_update_ops
def _setup(self):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return #4713
return
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
# do this before super._setup() because the input_source may need the global step
super(DistributedReplicatedTrainer, self)._setup()
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
devices=self.raw_devices,
var_strategy='replicated',
vs_names=None) # use the default vs names
avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
self._shadow_vars = [v for (_, v) in ps_var_grads]
self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)
# TODO add options to synchronize less
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self._add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
# initial local_vars syncing
cb = RunOp(self._get_initial_sync_op,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
self.register_callback(cb)
# model_variables syncing
if len(self._shadow_model_vars) and self.is_chief:
cb = RunOp(self._get_sync_model_vars_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequenctly.")
self.register_callback(cb)
self._set_session_creator()
def _set_session_creator(self):
old_sess_creator = self.config.session_creator
if not isinstance(old_sess_creator, NewSessionCreator) \
or self.config.session_config is not None:
raise ValueError(
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it with tf.train.Server.")
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
ready_op = tf.report_uninitialized_variables()
sm = tf.train.SessionManager(
local_init_op=local_init_op,
ready_op=ready_op, graph=tf.get_default_graph())
def _create_session():
if self.is_chief:
return sm.prepare_session(master=self.server.target, init_op=init_op)
else:
return sm.wait_for_session(master=self.server.target)
class _Creator(tf.train.SessionCreator):
def create_session(self):
return _create_session()
self.config.session_creator = _Creator()
def _add_sync_queues_and_barrier(self, name, dependencies):
"""Adds ops to enqueue on all worker queues.
Args:
name: prefix for the shared_name of the queue ops.
dependencies: ops to depend on (control dependencies) before enqueuing.
Returns:
an op that should be used as control dependency before starting next step.
"""
self.sync_queue_counter += 1
with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
sync_queues = [
tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
shared_name='%s%s' % (name, i))
for i in range(self.num_worker)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can finish this step.
token = tf.constant(False)
with tf.control_dependencies(dependencies):
for i, q in enumerate(sync_queues):
if i != self.task_index:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
queue_ops.append(
sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))
return tf.group(*queue_ops, name=name)
def _get_initial_sync_op(self):
"""
Get the op to copy-initialize all local variables from the PS.
"""
def strip_port(s):
if s.endswith(':0'):
return s[:-2]
return s
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
ops = []
nr_shadow_vars = len(self._shadow_vars)
for v in self._shadow_vars:
vname = strip_port(v.name)
for i in range(self.nr_gpu):
name = 'tower%s/%s' % (i, vname)
assert name in local_var_by_name, \
"Shadow variable {} doesn't match a corresponding local variable!".format(v.name)
copy_to = local_var_by_name[name]
# logger.info("{} -> {}".format(v.name, copy_to.name))
ops.append(copy_to.assign(v.read_value()))
return tf.group(*ops, name='sync_{}_variables_from_ps'.format(nr_shadow_vars))
def _get_sync_model_vars_op(self):
"""
Get the op to sync local model_variables to PS.
"""
ops = []
for (shadow_v, local_v) in self._shadow_model_vars:
ops.append(shadow_v.assign(local_v.read_value()))
assert len(ops)
return tf.group(*ops, name='sync_{}_model_variables_to_ps'.format(len(ops)))
@property
def vs_name_for_predictor(self):
return "tower0"
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from six.moves import zip
from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_source import QueueInput, FeedfreeInput
......@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" get the cost and gradient"""
self.build_train_tower()
cost = self.model.get_cost() # assume single cost
# opt may be created under first-tower variable scope (which is '')
opt = self.model.get_optimizer()
# GATE_NONE faster?
varlist = tf.trainable_variables()
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower
# TODO assumption on the first-tower empty variable scope
# TODO use ctx.vars?
varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
grads = opt.compute_gradients(
grads = tf.gradients(
cost,
var_list=varlist,
gate_gradients=tf.train.Optimizer.GATE_NONE,
varlist,
gate_gradients=False,
colocate_gradients_with_ops=True)
grads = list(zip(grads, varlist))
return cost, grads
......
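The gradient-computation change replaces `opt.compute_gradients` with a raw `tf.gradients` call zipped against the variable list; a tiny standalone sketch of that pattern, with an illustrative cost:

```python
# Sketch: raw tf.gradients paired manually with the variable list.
import tensorflow as tf

x = tf.get_variable('x', shape=[], initializer=tf.ones_initializer())
cost = tf.square(x)
varlist = [x]
grads = tf.gradients(cost, varlist,
                     gate_gradients=False,
                     colocate_gradients_with_ops=True)
grads_and_vars = list(zip(grads, varlist))
```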
......@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
def setup_training(self, trainer):
super(QueueInput, self).setup_training(trainer)
trainer.register_callback(StartProcOrThread(self.thread))
cb = StartProcOrThread(self.thread)
cb.chief_only = False
trainer.register_callback(cb)
def get_input_tensors(self):
with tf.device('/cpu:0'):
......@@ -365,6 +367,7 @@ class DummyConstantInput(TensorInput):
def fn():
tlist = []
ctx = get_current_tower_context()
assert ctx is not None
assert len(self.shapes) == len(self.input_placehdrs)
for idx, p in enumerate(self.input_placehdrs):
tlist.append(tf.get_variable(
......
......@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
class MultiGPUTrainerBase(Trainer):
""" Base class for multi-gpu training"""
@staticmethod
def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
def build_on_multi_tower(
towers, func,
devices=None, var_strategy='shared',
vs_names=None):
"""
Args:
towers: list of gpu relative ids
func: a lambda to be called inside each tower
devices: a list of devices to be used. By default will use GPUs in towers.
var_strategy (str):
var_strategy (str): 'shared' or 'replicated'
vs_names (list[str]): list of variable scope names to use.
Returns:
List of outputs of ``func``, evaluated on each tower.
......@@ -70,15 +74,20 @@ class MultiGPUTrainerBase(Trainer):
keys_to_freeze = TOWER_FREEZE_KEYS[:]
if var_strategy == 'replicated': # TODO ugly
logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
else:
assert vs_names is None
if vs_names is None:
vs_names = [None] * len(towers)
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
with TowerContext(
'tower{}'.format(idx),
device=device, is_training=True,
var_strategy=var_strategy):
var_strategy=var_strategy,
vs_name=vs_names[idx]):
if idx == t:
logger.info("Building graph for training tower {}...".format(idx))
else:
......@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
var_strategy='replicated')
var_strategy='replicated',
# use no variable scope for the first tower
vs_names=[''] + [None] * (self.config.nr_tower - 1))
grads = self._allreduce_grads(grad_list)
train_ops = []
......@@ -261,7 +272,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
self.train_op = tf.group(*train_ops, name='train_op')
self.register_callback(RunOp(
SyncMultiGPUTrainerReplicated.get_post_init_ops,
run_before=True, run_as_trigger=True))
run_before=True, run_as_trigger=True, verbose=True))
# Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
......@@ -279,7 +290,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
split_name = split_name[1:]
copy_from = var_by_name['/'.join(split_name)]
post_init_ops.append(v.assign(copy_from.read_value()))
return tf.group(*post_init_ops, name='init_sync_vars')
return tf.group(*post_init_ops, name='sync_variables_from_tower0')
class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
......
......@@ -3,6 +3,7 @@
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..predict import (OnlinePredictor,
PredictorTowerBuilder)
......@@ -19,6 +20,7 @@ class PredictorFactory(object):
"""
self.model = trainer.model
self.towers = trainer.config.predict_tower
self.vs_name = trainer.vs_name_for_predictor
def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs())
......@@ -34,7 +36,8 @@ class PredictorFactory(object):
"""
tower = self.towers[tower]
# just ensure the tower exists. won't rebuild (memoized)
self._tower_builder.build(tower)
with tf.variable_scope(self.vs_name, reuse=True):
self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
......