Commit 3ab6d2b0 authored by Yuxin Wu

variable scope issues for model saving / predictor

parent 95cb6ba2
......@@ -90,7 +90,8 @@ class InferenceRunnerBase(Callback):
def fn(_):
in_tensors = self._input_source.get_input_tensors()
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._hooks = [self._build_hook(inf) for inf in self.infs]
......
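To see what the hunk above buys, here is a minimal TF 1.x sketch of the reuse pattern: the predict tower is rebuilt under the trainer's variable scope with reuse=True, so it picks up the training tower's weights instead of creating fresh ones. build_model, the scope name 'tower0', and the shapes are illustrative only.

import tensorflow as tf

def build_model(x):
    # stand-in for ModelDesc.build_graph: one shared weight matrix
    w = tf.get_variable('W', shape=[4, 10])
    return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('tower0'):
    train_out = build_model(x)                 # creates tower0/W

# what the InferenceRunner change does, here with vs_name_for_predictor == 'tower0'
with tf.variable_scope('tower0', reuse=True):
    with tf.name_scope('towerp0'):             # predict-tower op names stay separate
        pred_out = build_model(x)              # reuses tower0/W, creates nothing new

assert len(tf.trainable_variables()) == 1      # only one copy of W exists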
......@@ -47,7 +47,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
for p in params:
para_name = p.name
# in replicated mode, only regularize variables inside this tower
if ctx.has_own_variables and (not para_name.startswith(ctx.vs_name)):
if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
continue
if re.search(regex, para_name):
costs.append(func(p))
......
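The extra ctx.vs_name check matters because, after this commit, the first replicated tower uses an empty variable-scope name. A tiny pure-Python rendering of the filter (variable names are made up):

def keep_for_regularization(para_name, has_own_variables, vs_name):
    # mirrors the `continue` condition above: skip params from other towers,
    # but never skip anything when the current tower's vs_name is empty
    return not (has_own_variables and vs_name and not para_name.startswith(vs_name))

assert keep_for_regularization('conv0/W:0', True, '')              # empty vs_name: keep everything
assert keep_for_regularization('tower1/conv0/W:0', True, 'tower1') # own tower: keep
assert not keep_for_regularization('tower1/conv0/W:0', True, 'tower2')  # other tower: skip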
......@@ -17,13 +17,16 @@ class TowerContext(object):
def __init__(self, tower_name,
device=None, is_training=None,
var_strategy='shared'):
var_strategy='shared',
vs_name=None):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
is_training (bool): if None, automatically determine from tower_name.
var_strategy (str): either 'shared' or 'replicated'.
vs_name (str): the variable scope name to open. Only valid in
'replicated' mode. Defaults to tower_name.
"""
self._name = tower_name
if device is None:
......@@ -38,6 +41,11 @@ class TowerContext(object):
self._var_strategy = var_strategy
if self._var_strategy == 'replicated':
assert self._name
            if vs_name is None:
                self._vs_name = self._name
            else:
                self._vs_name = vs_name
        else:
            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
            self._vs_name = ''
@property
def is_main_training_tower(self):
......@@ -62,12 +70,7 @@ class TowerContext(object):
# variable_scope name
@property
def vs_name(self):
if self.has_own_variables:
# do not open new variable scope for the main tower,
# just use '', so that Saver & PredictTower know what to do
if self.index > 0:
return self._name
return ""
return self._vs_name
@property
def index(self):
......@@ -113,13 +116,16 @@ class TowerContext(object):
self._ctxs = []
if len(self._name):
if self.has_own_variables:
if self.vs_name:
if len(self.vs_name):
self._ctxs.append(tf.variable_scope(self.vs_name))
else:
# use existing variable scope
reuse = self.index > 0 or (not self.is_training)
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=reuse))
if self.is_training:
reuse = self.index > 0
if reuse is True:
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True))
# if not training, should handle vs outside (TODO not good)
self._ctxs.append(tf.name_scope(self._name))
self._ctxs.append(tf.device(self._device))
for c in self._ctxs:
......
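A small illustration (TF 1.x; the variable 'W' is hypothetical) of why the main tower keeps an empty variable-scope name, as the removed comment above says: checkpoint keys come from variable names, so an unprefixed name in the main tower keeps checkpoints loadable by Saver and by predictors without renaming, while replica towers get 'tower<i>/' prefixes.

import tensorflow as tf

with tf.variable_scope('tower1'):
    w1 = tf.get_variable('W', shape=[3])   # named 'tower1/W:0'
w0 = tf.get_variable('W', shape=[3])       # named 'W:0' -- main tower, no prefix

saver = tf.train.Saver([w0])               # checkpoint key is simply 'W'
assert w0.name == 'W:0' and w1.name == 'tower1/W:0'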
......@@ -19,6 +19,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
__all__ = ['Trainer', 'StopTraining']
......@@ -44,6 +45,7 @@ class Trainer(object):
local_step (int): the number of steps that have finished in the current epoch.
global_step (int): the number of steps that have finished.
"""
# step attr only available after before_train?
is_chief = True
......@@ -124,11 +126,19 @@ class Trainer(object):
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
if self.is_chief:
self.config.session_init._setup_graph()
# This might finalize the graph (in distributed)
logger.info("Creating the session ...")
self._create_session()
logger.info("Initializing the session ...")
self.config.session_init.init(self.sess)
if self.is_chief:
logger.info("Initializing the session ...")
self.config.session_init._run_init(self.sess)
else:
assert isinstance(self.config.session_init, JustCurrentSession), \
"session_init is only valid for chief worker session!"
self.sess.graph.finalize()
logger.info("Graph Finalized.")
......@@ -164,6 +174,8 @@ class Trainer(object):
self._starting_step = get_global_step_value()
try:
self._callbacks.before_train()
# refresh global step (might have been changed by callbacks) TODO ugly
self._starting_step = get_global_step_value()
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num))
......@@ -190,6 +202,13 @@ class Trainer(object):
self._callbacks.after_train()
self.hooked_sess.close()
@property
def vs_name_for_predictor(self):
"""
The variable scope name a predictor should be built in.
"""
return ""
# Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0):
"""
......
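The contract of this new property, checked in isolation (TF 1.x; the variable 'W' is hypothetical): the base trainer returns an empty scope name, and reopening an empty-named variable scope with reuse=True resolves to root-scope variables, which is what the predictor-building call sites in this commit rely on.

import tensorflow as tf

w = tf.get_variable('W', shape=[2])            # root-scope variable 'W:0'
with tf.variable_scope('', reuse=True):        # vs_name_for_predictor == '' for the base Trainer
    w_again = tf.get_variable('W', shape=[2])
assert w is w_again                            # no new variable is created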
......@@ -14,18 +14,15 @@ from ..tfutils.common import get_global_step_var, get_op_tensor_name
__all__ = ['DistributedReplicatedTrainer']
# Note that only trainable vars are shadowed
PS_SHADOW_VAR_PREFIX = 'ps_var'
# TODO only trainable model vars are saved
class OverrideToLocalVariableIfNotPsVar(object):
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if name.startswith(PS_SHADOW_VAR_PREFIX):
return getter(*args, **kwargs)
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
......@@ -103,7 +100,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"""
ps_var_grads = []
for grad, var in avg_grads:
my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
assert var.name.startswith('tower'), var.name
my_name = '/'.join(var.name.split('/')[1:])
my_name = get_op_tensor_name(my_name)[0]
new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
initializer=var.initial_value,
......@@ -141,26 +139,29 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
logger.info("Running ps {}".format(self.task_index))
self.server.join()
return # TODO exit and skip mainloop how?
super(DistributedReplicatedTrainer, self)._setup()
with tf.device(self.param_server_device):
get_global_step_var()
gs = get_global_step_var()
assert gs.device, gs.device
self.model.get_optimizer() # TODO in global scope, not local
# do this before super.setup because input_source may need the global step
super(DistributedReplicatedTrainer, self)._setup()
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariableIfNotPsVar()):
custom_getter=OverrideToLocalVariable()):
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
devices=self.raw_devices,
var_strategy='replicated')
var_strategy='replicated',
vs_names=None) # use the default vs names
avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
self._shadow_vars = [v for (_, v) in ps_var_grads]
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier(
......@@ -180,7 +181,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to the tf.train.Server constructor.")
# TODO use scaffold
# TODO use scaffold + monitored session
class SupervisedSessionCreator(tf.train.SessionCreator):
def __init__(self, is_chief, target):
self.is_chief = is_chief
......@@ -239,18 +240,17 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
post_init_ops = []
for v in tf.global_variables():
if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/'):
prefix = strip_port(
v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
for i in range(self.nr_gpu):
if i == 0:
name = prefix # no prefix for tower0
else:
name = 'tower%s/%s' % (i, prefix)
if name in local_var_by_name:
copy_to = local_var_by_name[name]
post_init_ops.append(copy_to.assign(v.read_value()))
else:
logger.warn("Global varable {} doesn't match a corresponding local var".format(v.name))
for v in self._shadow_vars:
vname = strip_port(v.name)
for i in range(self.nr_gpu):
name = 'tower%s/%s' % (i, vname)
if name in local_var_by_name:
copy_to = local_var_by_name[name]
post_init_ops.append(copy_to.assign(v.read_value()))
else:
logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
return tf.group(*post_init_ops, name='sync_variables_from_ps')
@property
def vs_name_for_predictor(self):
return "tower0"
......@@ -71,7 +71,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower
# TODO assumption on the first-tower empty variable scope
# TODO use ctx.vars?
varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
grads = opt.compute_gradients(
cost,
......
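A pure-Python rendering of the per-tower gradient filter above (scope and variable names are made up): each replicated tower only optimizes variables created under its own variable scope, and an empty vs_name (the main tower) skips the filter entirely.

vs_name = 'tower1'                                    # hypothetical current tower
all_vars = ['tower0/conv/W', 'tower1/conv/W', 'tower1/fc/b']
if vs_name:                                           # empty vs_name: no filtering at all
    own = [n for n in all_vars if n.startswith(vs_name + '/')]
else:
    own = all_vars
assert own == ['tower1/conv/W', 'tower1/fc/b']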
......@@ -49,13 +49,17 @@ def apply_prefetch_policy(config, use_stage=True):
class MultiGPUTrainerBase(Trainer):
""" Base class for multi-gpu training"""
@staticmethod
def build_on_multi_tower(towers, func, devices=None, var_strategy='shared'):
def build_on_multi_tower(
towers, func,
devices=None, var_strategy='shared',
vs_names=None):
"""
Args:
towers: list of gpu relative ids
func: a lambda to be called inside each tower
devices: a list of devices to be used. By default will use GPUs in towers.
var_strategy (str):
var_strategy (str): 'shared' or 'replicated'
vs_names (list[str]): list of variable scope names to use.
Returns:
List of outputs of ``func``, evaluated on each tower.
......@@ -72,13 +76,18 @@ class MultiGPUTrainerBase(Trainer):
if var_strategy == 'replicated': # TODO ugly
logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
else:
assert vs_names is None
if vs_names is None:
vs_names = [None] * len(towers)
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
with TowerContext(
'tower{}'.format(idx),
device=device, is_training=True,
var_strategy=var_strategy):
var_strategy=var_strategy,
vs_name=vs_names[idx]):
if idx == t:
logger.info("Building graph for training tower {}...".format(idx))
else:
......@@ -248,7 +257,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
var_strategy='replicated')
var_strategy='replicated',
# use no variable scope for the first tower
vs_names=[''] + [None] * (self.config.nr_tower - 1))
grads = self._allreduce_grads(grad_list)
train_ops = []
......
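With the parenthesization fixed above, the per-tower scope names for a hypothetical 4-tower SyncMultiGPUTrainerReplicated look like this: the first tower builds at the root scope (''), and the None entries fall back to the default tower names inside TowerContext.

nr_tower = 4                                   # hypothetical
vs_names = [''] + [None] * (nr_tower - 1)
assert vs_names == ['', None, None, None]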
......@@ -3,6 +3,7 @@
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..predict import (OnlinePredictor,
PredictorTowerBuilder)
......@@ -19,6 +20,7 @@ class PredictorFactory(object):
"""
self.model = trainer.model
self.towers = trainer.config.predict_tower
self.vs_name = trainer.vs_name_for_predictor
def fn(_):
self.model.build_graph(self.model.get_reused_placehdrs())
......@@ -34,7 +36,8 @@ class PredictorFactory(object):
"""
tower = self.towers[tower]
# just ensure the tower exists. won't rebuild (memoized)
self._tower_builder.build(tower)
with tf.variable_scope(self.vs_name, reuse=True):
self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
......