Commit f0573ed2 authored by Yuxin Wu

Merge branch 'distributed' (#144)

parents a3674b47 930481f2
@@ -36,6 +36,8 @@ class Callback(object):
     .. automethod:: _after_train
     """

+    _chief_only = True
+
     def setup_graph(self, trainer):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer
@@ -162,6 +164,19 @@ class Callback(object):
     def local_step(self):
         return self.trainer.local_step

+    @property
+    def chief_only(self):
+        """
+        Only run this callback on chief training process.
+
+        Returns: bool
+        """
+        return self._chief_only
+
+    @chief_only.setter
+    def chief_only(self, v):
+        self._chief_only = v
+
     def __str__(self):
         return type(self).__name__
...
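The new `chief_only` attribute and property let a callback declare whether it should run only on the chief worker. A minimal usage sketch (the `MyCallback` class is hypothetical; it assumes only the `Callback` API shown in this hunk):

    from tensorpack.callbacks import Callback

    class MyCallback(Callback):        # hypothetical callback, for illustration only
        def _trigger(self):
            print("triggered on this worker")

    cb = MyCallback()
    print(cb.chief_only)               # True by default (from _chief_only)
    cb.chief_only = False              # ask the trainer to run it on every worker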
@@ -17,13 +17,15 @@ class RunOp(Callback):
     """ Run an Op. """

     def __init__(self, setup_func,
-                 run_before=True, run_as_trigger=True, run_step=False):
+                 run_before=True, run_as_trigger=True,
+                 run_step=False, verbose=False):
         """
         Args:
             setup_func: a function that returns the Op in the graph
             run_before (bool): run the Op before training
             run_as_trigger (bool): run the Op on every trigger
             run_step (bool): run the Op every step (along with training)
+            verbose (bool): print logs when the op is run.

         Examples:
             The `DQN Example
@@ -34,27 +36,38 @@ class RunOp(Callback):
         self.run_before = run_before
         self.run_as_trigger = run_as_trigger
         self.run_step = run_step
+        self.verbose = verbose

     def _setup_graph(self):
         self._op = self.setup_func()

     def _before_train(self):
         if self.run_before:
+            self._print()
             self._op.run()

     def _trigger(self):
         if self.run_as_trigger:
+            self._print()
             self._op.run()

     def _before_run(self, _):
         if self.run_step:
+            self._print()
             return [self._op]

+    def _print(self):
+        if self.verbose:
+            logger.info("Running Op {} ...".format(self._op.name))
+

 class RunUpdateOps(RunOp):
     """
     Run ops from the collection UPDATE_OPS every step
     """

+    _chief_only = False
+
     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
         def f():
             ops = tf.get_collection(collection)
...
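A short sketch of the new `verbose` flag on `RunOp` (assuming TensorFlow 1.x and that `RunOp` is importable from `tensorpack.callbacks` in this release; the op itself is just a placeholder):

    import tensorflow as tf
    from tensorpack.callbacks import RunOp

    # Logs "Running Op ..." before training and on every trigger.
    cb = RunOp(lambda: tf.no_op(name='my_op'),     # hypothetical op for the sketch
               run_before=True, run_as_trigger=True, verbose=True)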
@@ -90,7 +90,8 @@ class InferenceRunnerBase(Callback):
         def fn(_):
             in_tensors = self._input_source.get_input_tensors()
             self.trainer.model.build_graph(in_tensors)
-        PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
+        with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
+            PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
         self._hooks = [self._build_hook(inf) for inf in self.infs]
...
@@ -72,7 +72,7 @@ class GraphVarParam(HyperParam):
                 self.var = v
                 break
         else:
-            raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
+            raise ValueError("{} is not a GLOBAL_VARIABLE in the graph!".format(self.var_name))

     def set_value(self, v):
         """ Assign the variable a new value. """
...
@@ -43,6 +43,7 @@ class ModelSaver(Callback):
         vars = []
         for key in self.var_collections:
             vars.extend(tf.get_collection(key))
+        vars = list(set(vars))
         self.path = os.path.join(self.checkpoint_dir, 'model')
         if get_tf_version_number() <= 1.1:
             self.saver = tf.train.Saver(
...
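The added `list(set(vars))` only removes duplicates: the same variable object can appear in more than one of the collections being saved. A standalone illustration (TensorFlow 1.x assumed; the collections below are examples, not necessarily ModelSaver's defaults):

    import tensorflow as tf

    w = tf.get_variable('w', shape=[2])      # trainable, so it is in both collections below
    vars = []
    for key in [tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES]:
        vars.extend(tf.get_collection(key))
    assert len(vars) == 2                    # 'w' collected twice
    vars = list(set(vars))                   # deduplicate before building the Saver
    assert len(vars) == 1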
@@ -55,13 +55,14 @@ class MaintainStepCounter(Callback):
         # ensure it exists
         gs_var = get_global_step_var()
         with tf.name_scope(None):
-            self.gs_incr_var = tf.assign_add(
-                gs_var, 1,
-                name=GLOBAL_STEP_INCR_OP_NAME)
+            with tf.device(gs_var.device):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
             # tf.mod(
             #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
             #     name=LOCAL_STEP_OP_NAME)
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

     def _before_train(self):
         gs_val = get_global_step_value()
@@ -81,6 +82,8 @@ class MaintainStepCounter(Callback):

 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """

+    _chief_only = False
+
     def __init__(self, names=[]):
         """
         Args:
...
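The rewritten increment op is pinned to the device of the global-step variable, which matters once that variable lives on a parameter server in distributed training. A standalone sketch of the same pattern (TensorFlow 1.x assumed; names are illustrative):

    import tensorflow as tf

    gs_var = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
    with tf.name_scope(None):                  # keep the op out of any current name scope
        with tf.device(gs_var.device):         # colocate the increment with the variable
            gs_incr_op = tf.assign_add(gs_var, 1, name='global_step_incr').op

    fetches = tf.train.SessionRunArgs(gs_incr_op)   # fetch the Operation, not its output value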
@@ -136,7 +136,7 @@ def layer_register(
                 # log shape info and add activation
                 logger.info("{} output: {}".format(
                     scope.name, get_shape_str(outputs)))
-                _LAYER_LOGGED.add(scope.name)
+                _LAYER_LOGGED.add(scope_name)
             else:
                 # run the actual function
                 outputs = func(*args, **actual_args)
...
@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.name
         # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
+        if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
             continue
         if re.search(regex, para_name):
             costs.append(func(p))
...
@@ -39,9 +39,11 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.inter_op_parallelism_threads = 0
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction

+    if get_tf_version_number() >= 1.2:
+        conf.gpu_options.force_gpu_compatible = True
+
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
-    # force gpu compatible?
     conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
     return conf
...
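A self-contained sketch of the same configuration logic; `force_gpu_compatible` only exists from TensorFlow 1.2, hence the version guard (this sketch re-implements the version check inline instead of using tensorpack's `get_tf_version_number`):

    import tensorflow as tf

    def default_sess_config(mem_fraction=0.99):
        conf = tf.ConfigProto()
        conf.inter_op_parallelism_threads = 0
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        if tuple(map(int, tf.__version__.split('.')[:2])) >= (1, 2):
            conf.gpu_options.force_gpu_compatible = True   # use pinned host memory for GPU transfers
        conf.gpu_options.allocator_type = 'BFC'
        conf.gpu_options.allow_growth = True
        conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        return conf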
@@ -154,11 +154,13 @@ def add_moving_summary(v, *args, **kwargs):
     for x in v:
         assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-    # TODO will produce tower0/xxx?
+    # TODO will produce variable tower0/xxx?
+    # TODO not saved under distributed
     # TODO use zero_debias
-    with tf.name_scope(None):
+    gs = get_global_step_var()
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
-            decay, num_updates=get_global_step_var(), name='EMA')
+            decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
     for c in v:
...
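The updated code creates the moving-average variables on the same device as the global step. A minimal standalone version of the pattern (TensorFlow 1.x assumed, with a dummy scalar standing in for a real summary tensor):

    import tensorflow as tf

    gs = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
    loss = tf.reduce_mean(tf.random_normal([4]), name='loss')    # dummy scalar to average

    with tf.name_scope(None), tf.device(gs.device):
        averager = tf.train.ExponentialMovingAverage(0.95, num_updates=gs, name='EMA')
        avg_maintain_op = averager.apply([loss])    # op that updates the moving average
        ema_value = averager.average(loss)          # variable holding the averaged value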
@@ -17,13 +17,16 @@ class TowerContext(object):
     def __init__(self, tower_name,
                  device=None, is_training=None,
-                 var_strategy='shared'):
+                 var_strategy='shared',
+                 vs_name=None):
         """
         Args:
             tower_name (str): 'tower0', 'towerp0', or ''
             device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
             is_training (bool): if None, automatically determine from tower_name.
             var_strategy (str): either 'shared' or 'replicated'.
+            vs_name (str): the variable scope name to open. Only valid in
+                'replicated' mode. Defaults to be tower_name.
         """
         self._name = tower_name
         if device is None:
@@ -38,6 +41,13 @@ class TowerContext(object):
         self._var_strategy = var_strategy
         if self._var_strategy == 'replicated':
             assert self._name
+            if vs_name is None:
+                self._vs_name = self._name
+            else:
+                self._vs_name = vs_name
+        else:
+            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
+            self._vs_name = ''

     @property
     def is_main_training_tower(self):
@@ -62,12 +72,7 @@ class TowerContext(object):
     # variable_scope name
     @property
     def vs_name(self):
-        if self.has_own_variables:
-            # do not open new variable scope for the main tower,
-            # just use '', so that Saver & PredictTower know what to do
-            if self.index > 0:
-                return self._name
-        return ""
+        return self._vs_name

     @property
     def index(self):
@@ -113,13 +118,16 @@ class TowerContext(object):
         self._ctxs = []
         if len(self._name):
             if self.has_own_variables:
-                if self.vs_name:
+                if len(self.vs_name):
                     self._ctxs.append(tf.variable_scope(self.vs_name))
             else:
-                # use existing variable scope
-                reuse = self.index > 0 or (not self.is_training)
-                self._ctxs.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=reuse))
+                if self.is_training:
+                    reuse = self.index > 0
+                    if reuse is True:
+                        self._ctxs.append(tf.name_scope(None))
+                        self._ctxs.append(tf.variable_scope(
+                            tf.get_variable_scope(), reuse=True))
+                # if not training, should handle vs outside (TODO not good)
             self._ctxs.append(tf.name_scope(self._name))
         self._ctxs.append(tf.device(self._device))
         for c in self._ctxs:
...
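A sketch of the new `vs_name` argument, following the rule asserted above that it is only valid with `var_strategy='replicated'` (import path taken from the trainer code further below; scope names here are illustrative):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext

    # tower0 keeps the root variable scope, tower1 gets its own copy of the variables.
    with TowerContext('tower0', is_training=True,
                      var_strategy='replicated', vs_name=''):
        a = tf.get_variable('w', shape=[3])     # created as 'w'

    with TowerContext('tower1', is_training=True,
                      var_strategy='replicated', vs_name='tower1'):
        b = tf.get_variable('w', shape=[3])     # created as 'tower1/w'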
@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
     new_path = model_path.split('.index')[0]
     if new_path != model_path:
         logger.warn(
-            "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+            "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
         model_path = new_path
     assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
     return model_path
@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):

 def is_training_name(name):
     """
-    This is a hack temporarily used to improve logging. Do not use this function.
+    Guess if a name belongs to a training-only variable.
+    Only used internally to avoid too much logging. Do not use it.

     Returns:
         bool: Guess whether this tensor is something only used in training.
...
@@ -9,8 +9,6 @@ import six
 from six.moves import range

 import tensorflow as tf
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession

 from .predict import PredictorFactory
 from .config import TrainConfig
@@ -21,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.sessinit import JustCurrentSession

 __all__ = ['Trainer', 'StopTraining']
@@ -46,6 +45,9 @@ class Trainer(object):
         local_step (int): the number of steps that have finished in the current epoch.
         global_step (int): the number of steps that have finished.
     """

+    # step attr only available after before_train?
+    is_chief = True
+
     def __init__(self, config):
         """
@@ -79,14 +81,20 @@ class Trainer(object):
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
-        self._callbacks.append(cb)
+        if not self.is_chief and cb.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+        else:
+            self._callbacks.append(cb)

     def register_monitor(self, mon):
         assert isinstance(mon, TrainingMonitor), mon
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
-        self.monitors.append(mon)
-        self.register_callback(mon)
+        if not self.is_chief and mon.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
+        else:
+            self.monitors.append(mon)
+            self.register_callback(mon)

     def train(self):
         """ Start training """
@@ -110,6 +118,7 @@ class Trainer(object):
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)

+        # TODO cache per graph, avoid describing all towers
         describe_model()

         # some final operations that might modify the graph
@@ -117,21 +126,28 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))

-        # create session
         logger.info("Creating the session ...")
-        self.sess = self.config.session_creator.create_session()
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
+        self._create_session()

-        logger.info("Initializing the session ...")
-        # init session
-        self.config.session_init.init(self.sess)
+        if self.is_chief:
+            logger.info("Initializing the session ...")
+            self.config.session_init.init(self.sess)
+        else:
+            assert isinstance(self.config.session_init, JustCurrentSession), \
+                "session_init is only valid for chief worker session!"

         self.sess.graph.finalize()
         logger.info("Graph Finalized.")

-        hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+    def _create_session(self):
+        """
+        Setup self.sess (the raw tf.Session)
+        and self.hooked_sess (the session with hooks and coordinator)
+        """
+        hooks = self._callbacks.get_hooks()
+        self.sess = self.config.session_creator.create_session()
+        self.hooked_sess = tf.train.MonitoredSession(
+            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

     @abstractmethod
     def _setup(self):
@@ -154,12 +170,14 @@ class Trainer(object):
         self._starting_step = get_global_step_value()
         try:
             self._callbacks.before_train()
+            # refresh global step (might have changed by callbacks) TODO ugly
+            self._starting_step = get_global_step_value()
             for self.epoch_num in range(
                     self.config.starting_epoch, self.config.max_epoch + 1):
                 logger.info("Start Epoch {} ...".format(self.epoch_num))
                 start_time = time.time()
                 for self.local_step in range(self.config.steps_per_epoch):
-                    if self._monitored_sess.should_stop():
+                    if self.hooked_sess.should_stop():
                         return
                     self.run_step()  # implemented by subclass
                     self._callbacks.trigger_step()
@@ -169,6 +187,7 @@ class Trainer(object):
                 # trigger epoch outside the timing region.
                 self._trigger_epoch()
                 self._callbacks.trigger_epoch()
+            logger.info("Training has finished!")
         except (StopTraining, tf.errors.OutOfRangeError):
             logger.info("Training was stopped.")
         except KeyboardInterrupt:
@@ -177,7 +196,14 @@ class Trainer(object):
             raise
         finally:
             self._callbacks.after_train()
-            self._monitored_sess.close()
+            self.hooked_sess.close()

+    @property
+    def vs_name_for_predictor(self):
+        """
+        The variable scope name a predictor should be built in.
+        """
+        return ""
+
     # Predictor related methods: TODO
     def get_predictor(self, input_names, output_names, tower=0):
...
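With the new `is_chief` flag, only the chief runs `session_init`; non-chief workers must pass `JustCurrentSession`. A sketch of how a config might be set up under that rule (`is_chief` and the checkpoint path are placeholders for this illustration):

    from tensorpack.tfutils.sessinit import JustCurrentSession, SaverRestore

    is_chief = True                                    # e.g. task index 0 in the cluster
    if is_chief:
        session_init = SaverRestore('/path/to/model')  # chief restores / initializes variables
    else:
        session_init = JustCurrentSession()            # workers reuse the already-created session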
This diff is collapsed.
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
+from six.moves import zip

 from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_source import QueueInput, FeedfreeInput
@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
-        # opt may be created under first-tower variable scope (which is '')
-        opt = self.model.get_optimizer()
-        # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
-            # TODO assumption on the first-tower empty variable scope
+            # TODO use ctx.vars?
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
-        grads = opt.compute_gradients(
+        grads = tf.gradients(
             cost,
-            var_list=varlist,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            varlist,
+            gate_gradients=False,
             colocate_gradients_with_ops=True)
+        grads = list(zip(grads, varlist))
         return cost, grads
...
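The switch from `opt.compute_gradients` to `tf.gradients` removes the dependency on an optimizer object that may have been created under another tower's variable scope; the `zip` with `varlist` rebuilds the same (grad, var) pairs. A standalone sketch (TensorFlow 1.x assumed):

    import tensorflow as tf
    from six.moves import zip

    x = tf.get_variable('x', initializer=3.0)
    cost = tf.square(x, name='cost')
    varlist = tf.trainable_variables()

    grads = tf.gradients(cost, varlist,
                         gate_gradients=False,
                         colocate_gradients_with_ops=True)
    grads = list(zip(grads, varlist))    # same pairs opt.compute_gradients would return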
@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
     def setup_training(self, trainer):
         super(QueueInput, self).setup_training(trainer)
-        trainer.register_callback(StartProcOrThread(self.thread))
+        cb = StartProcOrThread(self.thread)
+        cb.chief_only = False
+        trainer.register_callback(cb)

     def get_input_tensors(self):
         with tf.device('/cpu:0'):
@@ -365,6 +367,7 @@ class DummyConstantInput(TensorInput):
         def fn():
             tlist = []
             ctx = get_current_tower_context()
+            assert ctx is not None
             assert len(self.shapes) == len(self.input_placehdrs)
             for idx, p in enumerate(self.input_placehdrs):
                 tlist.append(tf.get_variable(
...
@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
 class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
+    def build_on_multi_tower(
+            towers, func,
+            devices=None, var_strategy='shared',
+            vs_names=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
-            var_strategy (str):
+            var_strategy (str): 'shared' or 'replicated'
+            vs_names (list[str]): list of variable scope names to use.

         Returns:
             List of outputs of ``func``, evaluated on each tower.
@@ -70,15 +74,20 @@ class MultiGPUTrainerBase(Trainer):
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':    # TODO ugly
-            logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
+            logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+        else:
+            assert vs_names is None
+        if vs_names is None:
+            vs_names = [None] * len(towers)

         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with TowerContext(
                     'tower{}'.format(idx),
                     device=device, is_training=True,
-                    var_strategy=var_strategy):
+                    var_strategy=var_strategy,
+                    vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
-            var_strategy='replicated')
+            var_strategy='replicated',
+            # use no variable scope for the first tower
+            vs_names=[''] + [None] * (self.config.nr_tower - 1))
         grads = self._allreduce_grads(grad_list)

         train_ops = []
@@ -261,7 +272,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
         self.train_op = tf.group(*train_ops, name='train_op')
         self.register_callback(RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
-            run_before=True, run_as_trigger=True))
+            run_before=True, run_as_trigger=True, verbose=True))

 # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
@@ -279,7 +290,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
             split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
-        return tf.group(*post_init_ops, name='init_sync_vars')
+        return tf.group(*post_init_ops, name='sync_variables_from_tower0')


 class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
...
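A sketch of the call shape the replicated trainer now uses, with the first tower mapped to the root variable scope (the module path and the trivial `func` are assumptions made for this illustration only):

    from tensorpack.train.multigpu import MultiGPUTrainerBase

    towers = [0, 1]                     # relative GPU ids
    func = lambda: []                   # would normally build one tower's graph and return its grads

    grad_list = MultiGPUTrainerBase.build_on_multi_tower(
        towers, func,
        var_strategy='replicated',
        vs_names=[''] + [None] * (len(towers) - 1))   # tower0 uses the root scope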
@@ -3,6 +3,7 @@
 # File: predict.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

+import tensorflow as tf
 from ..predict import (OnlinePredictor,
                        PredictorTowerBuilder)
@@ -19,6 +20,7 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+        self.vs_name = trainer.vs_name_for_predictor

         def fn(_):
             self.model.build_graph(self.model.get_reused_placehdrs())
@@ -34,7 +36,8 @@ class PredictorFactory(object):
         """
         tower = self.towers[tower]
         # just ensure the tower exists. won't rebuild (memoized)
-        self._tower_builder.build(tower)
+        with tf.variable_scope(self.vs_name, reuse=True):
+            self._tower_builder.build(tower)

         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
         get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
...