Commit e121701a authored by Yuxin Wu

Allow TowerContext to use both ns_name and vs_name

parent 0addcdc6
......@@ -149,7 +149,7 @@ class InferenceRunner(InferenceRunnerBase):
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=self._tower_name,
vs_name='', device=0).build( # TODO fix vs_name and maybe device
vs_name=self.trainer._main_tower_vs_name, device=0).build(
self._input_source, self.trainer.tower_func)
self._tower_handle = self.trainer.tower_func.towers[-1]
......@@ -224,7 +224,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
tower_name = self._tower_names[idx]
SimplePredictBuilder(
ns_name=tower_name,
vs_name='', device=t).build( # TODO fix vs_name and maybe device
vs_name=self.trainer._main_tower_vs_name, device=t).build(
self._input_source, self.trainer.tower_func)
self._handles.append(self.trainer.tower_func.towers[-1])
......
......@@ -26,7 +26,6 @@ class SimplePredictBuilder(GraphBuilder):
vs_name (str):
device (int):
"""
# TODO does vs_name work properly here when different from ns_name?
self._ns_name = ns_name
self._vs_name = vs_name
......@@ -56,7 +55,8 @@ class SimplePredictBuilder(GraphBuilder):
with tf.device(self._device), \
self._maybe_open_vs(), \
TowerContext(self._ns_name, is_training=False), \
TowerContext(
self._ns_name, is_training=False, vs_name=self._vs_name), \
freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
# also freeze UPDATE_OPS in inference, because they should never be used
# TODO a better way to log and warn about collection change during build_graph.
......
......@@ -88,7 +88,7 @@ class DataParallelBuilder(GraphBuilder):
tower_names[idx],
is_training=True,
index=idx,
use_vs=usevs):
vs_name=tower_names[idx] if usevs else ''):
logger.info("Building graph for training tower {} on device {}...".format(idx, device))
# When use_vs is True, use LOCAL_VARIABLES,
......
......@@ -7,6 +7,7 @@ import tensorflow as tf
from six.moves import zip
from ..utils import logger
from ..utils.argtools import call_only_once
from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper']
......@@ -17,30 +18,30 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name, is_training, index=0, use_vs=False):
def __init__(self, tower_name, is_training, index=0, vs_name=''):
"""
Args:
tower_name (str): The name scope of the tower.
is_training (bool):
index (int): index of this tower, only used in training.
use_vs (bool): Open a new variable scope with this name.
vs_name (str): Open a new variable scope with this name.
"""
self._name = tower_name
self._is_training = bool(is_training)
if not self._is_training:
assert index == 0 and not use_vs, \
"use_vs and index are only used in training!"
assert index == 0, \
"TowerContext(index) is only valid in training!"
self._index = int(index)
if use_vs:
self._vs_name = self._name
assert len(self._name)
else:
self._vs_name = ''
self._vs_name = vs_name
if len(vs_name):
assert len(tower_name), "TowerContext(vs_name) cannot be used with an empty tower_name!"
self._initial_vs_reuse = tf.get_variable_scope().reuse
if self.has_own_variables:
assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)
assert not self._initial_vs_reuse, \
"Cannot create tower {} with reuse=True!".format(tower_name)
@property
def is_main_training_tower(self):
......@@ -55,7 +56,9 @@ class TowerContext(object):
"""
Whether this tower is supposed to have its own variables.
"""
return self.is_main_training_tower or len(self._vs_name) > 0
return self.is_main_training_tower or \
(self.is_training and len(self._vs_name) > 0) or \
(not self.is_training and len(self._vs_name) > 0 and not self._initial_vs_reuse)
# TODO clarify the interface on name/vs_name/ns_name.
# TODO in inference, vs_name may need to be different from ns_name.
......@@ -72,6 +75,7 @@ class TowerContext(object):
def ns_name(self):
return self._name
# TODO another method to filter by ns_name
def filter_vars_by_vs_name(self, varlist):
"""
Filter the list and only keep those under the current variable scope.
......@@ -93,32 +97,36 @@ class TowerContext(object):
def index(self):
return self._index
@call_only_once
def _get_scopes(self):
    """
    Assemble the scope context managers that this tower has to enter.

    Returns:
        list: empty when the tower has no name. Otherwise it holds at most
        one ``tf.variable_scope`` followed by at most one ``tf.name_scope``,
        in the order they should be entered.
    """
    if not self._name:
        return []

    # Reuse is required either because the tower was originally created
    # under a reusing variable scope, or because it is a non-first training
    # tower that does not own its variables.
    shares_variables = self.is_training and self._index > 0 \
        and not self.has_own_variables
    reuse = shares_variables or self._initial_vs_reuse

    scopes = []
    if self._vs_name:
        scopes.append(tf.variable_scope(self._vs_name, reuse=reuse))
    elif reuse:
        # No dedicated variable scope: re-open the current one in reuse mode.
        scopes.append(
            tf.variable_scope(tf.get_variable_scope(), reuse=True))

    # always clear existing ns  # TODO check existing ns
    # A name scope is only added when it differs from the variable scope name.
    if self._name != self._vs_name:
        scopes.append(tf.name_scope(self._name + '/'))
    return scopes
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
_CurrentTowerContext = self
self._ctxs = []
curr_vs = tf.get_variable_scope()
assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
if len(self._name):
if not self.is_training:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self._ctxs.append(tf.name_scope(None))
self._ctxs.append(tf.name_scope(self._name))
else:
if self.has_own_variables:
if len(self._vs_name):
self._ctxs.append(tf.variable_scope(self._vs_name))
else:
self._ctxs.append(tf.name_scope(self._name))
else:
reuse = self._index > 0
if reuse:
self._ctxs.append(tf.variable_scope(
tf.get_variable_scope(), reuse=True))
self._ctxs.append(tf.name_scope(self._name))
self._ctxs = self._get_scopes()
for c in self._ctxs:
c.__enter__()
......
......@@ -269,13 +269,21 @@ class TowerTrainer(Trainer):
input.setup(self.inputs_desc)
SimplePredictBuilder(
ns_name=tower_name, vs_name='',
ns_name=tower_name, vs_name=self._main_tower_vs_name,
device=device).build(input, self.tower_func)
tower = self.tower_func.towers[tower_name]
input_tensors = tower.get_tensors(input_names)
output_tensors = tower.get_tensors(output_names)
return OnlinePredictor(input_tensors, output_tensors)
@property
def _main_tower_vs_name(self):
    """
    Variable scope name of the "main" copy of the model, i.e. the copy
    that predictors are built against.

    Returns:
        str: always the empty string here.
    """
    return ""
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment