Commit e121701a authored by Yuxin Wu

Allow TowerContext to use both ns_name and vs_name

parent 0addcdc6
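For context, here is a minimal sketch (not part of the commit) of how the reworked `TowerContext` is meant to be used after this change, assuming a TF1-style graph. The tower names `'tower0'` and `'InferenceTower'` are illustrative only.

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

# Main training tower: variables live at the root variable scope,
# ops live under the 'tower0' name scope.
with TowerContext('tower0', is_training=True, index=0):
    w = tf.get_variable('w', shape=[3])      # variable name: 'w'
    y = tf.identity(w, name='out')           # op name: 'tower0/out'

# Inference tower: new name scope for the ops, but reuse the variables
# created by the main tower (vs_name='' with reuse=True on the root scope),
# mirroring what InferenceRunner does in the hunk below.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    with TowerContext('InferenceTower', is_training=False, vs_name=''):
        w2 = tf.get_variable('w', shape=[3])  # resolves to the existing 'w'
        y2 = tf.identity(w2, name='out')      # op name: 'InferenceTower/out'
```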
@@ -149,7 +149,7 @@ class InferenceRunner(InferenceRunnerBase):
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             SimplePredictBuilder(
                 ns_name=self._tower_name,
-                vs_name='', device=0).build(  # TODO fix vs_name and maybe device
+                vs_name=self.trainer._main_tower_vs_name, device=0).build(
                 self._input_source, self.trainer.tower_func)
         self._tower_handle = self.trainer.tower_func.towers[-1]
@@ -224,7 +224,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             tower_name = self._tower_names[idx]
             SimplePredictBuilder(
                 ns_name=tower_name,
-                vs_name='', device=t).build(  # TODO fix vs_name and maybe device
+                vs_name=self.trainer._main_tower_vs_name, device=t).build(
                 self._input_source, self.trainer.tower_func)
             self._handles.append(self.trainer.tower_func.towers[-1])
...
@@ -26,7 +26,6 @@ class SimplePredictBuilder(GraphBuilder):
             vs_name (str):
             device (int):
         """
-        # TODO does vs_name work properly here when different from ns_name?
         self._ns_name = ns_name
         self._vs_name = vs_name
@@ -56,7 +55,8 @@ class SimplePredictBuilder(GraphBuilder):
         with tf.device(self._device), \
                 self._maybe_open_vs(), \
-                TowerContext(self._ns_name, is_training=False), \
+                TowerContext(
+                    self._ns_name, is_training=False, vs_name=self._vs_name), \
                 freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
             # also freeze UPDATE_OPS in inference, because they should never be used
             # TODO a better way to log and warn about collection change during build_graph.
...
@@ -88,7 +88,7 @@ class DataParallelBuilder(GraphBuilder):
                     tower_names[idx],
                     is_training=True,
                     index=idx,
-                    use_vs=usevs):
+                    vs_name=tower_names[idx] if usevs else ''):
                 logger.info("Building graph for training tower {} on device {}...".format(idx, device))
                 # When use_vs is True, use LOCAL_VARIABLES,
...
@@ -7,6 +7,7 @@ import tensorflow as tf
 from six.moves import zip
 from ..utils import logger
+from ..utils.argtools import call_only_once
 from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
 __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper']
@@ -17,30 +18,30 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """
-    def __init__(self, tower_name, is_training, index=0, use_vs=False):
+    def __init__(self, tower_name, is_training, index=0, vs_name=''):
         """
         Args:
             tower_name (str): The name scope of the tower.
             is_training (bool):
             index (int): index of this tower, only used in training.
-            use_vs (bool): Open a new variable scope with this name.
+            vs_name (str): Open a new variable scope with this name.
         """
         self._name = tower_name
         self._is_training = bool(is_training)
         if not self._is_training:
-            assert index == 0 and not use_vs, \
-                "use_vs and index are only used in training!"
+            assert index == 0, \
+                "TowerContext(index) is only valid in training!"
         self._index = int(index)
-        if use_vs:
-            self._vs_name = self._name
-            assert len(self._name)
-        else:
-            self._vs_name = ''
+        self._vs_name = vs_name
+        if len(vs_name):
+            assert len(tower_name), "TowerContext(vs_name) cannot be used with an empty tower_name!"
+        self._initial_vs_reuse = tf.get_variable_scope().reuse
         if self.has_own_variables:
-            assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)
+            assert not self._initial_vs_reuse, \
+                "Cannot create tower {} with reuse=True!".format(tower_name)
     @property
     def is_main_training_tower(self):
@@ -55,7 +56,9 @@ class TowerContext(object):
         """
         Whether this tower is supposed to have its own variables.
         """
-        return self.is_main_training_tower or len(self._vs_name) > 0
+        return self.is_main_training_tower or \
+            (self.is_training and len(self._vs_name) > 0) or \
+            (not self.is_training and len(self._vs_name) > 0 and not self._initial_vs_reuse)
     # TODO clarify the interface on name/vs_name/ns_name.
     # TODO in inference, vs_name may need to be different from ns_name.i
@@ -72,6 +75,7 @@ class TowerContext(object):
     def ns_name(self):
         return self._name
+    # TODO another method to filter by ns_name
     def filter_vars_by_vs_name(self, varlist):
         """
         Filter the list and only keep those under the current variable scope.
@@ -93,32 +97,36 @@ class TowerContext(object):
     def index(self):
         return self._index
+    @call_only_once
+    def _get_scopes(self):
+        if not len(self._name):
+            return []
+        ret = []
+        # either the Tower was originally created with reuse,
+        # or a training tower without vs has to use reuse.
+        reuse = (self.is_training and self._index > 0 and not
+                 self.has_own_variables) or self._initial_vs_reuse
+        if len(self._vs_name):
+            ret.append(tf.variable_scope(self._vs_name, reuse=reuse))
+        else:
+            if reuse:
+                ret.append(tf.variable_scope(
+                    tf.get_variable_scope(), reuse=True))
+        # always clear existing ns  # TODO check existing ns
+        if len(self._name) and self._name != self._vs_name:
+            ret.append(tf.name_scope(self._name + '/'))
+        return ret
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
-        self._ctxs = []
         curr_vs = tf.get_variable_scope()
         assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
-        if len(self._name):
-            if not self.is_training:
-                # if not training, should handle reuse outside
-                # but still good to clear name_scope first
-                self._ctxs.append(tf.name_scope(None))
-                self._ctxs.append(tf.name_scope(self._name))
-            else:
-                if self.has_own_variables:
-                    if len(self._vs_name):
-                        self._ctxs.append(tf.variable_scope(self._vs_name))
-                    else:
-                        self._ctxs.append(tf.name_scope(self._name))
-                else:
-                    reuse = self._index > 0
-                    if reuse:
-                        self._ctxs.append(tf.variable_scope(
-                            tf.get_variable_scope(), reuse=True))
-                    self._ctxs.append(tf.name_scope(self._name))
+        self._ctxs = self._get_scopes()
         for c in self._ctxs:
             c.__enter__()
...
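To make the new scope-selection rule easier to follow, here is a small standalone paraphrase of `_get_scopes` (a sketch derived from the hunk above, not code from the commit). It only reports which scopes would be opened instead of creating real TensorFlow scopes.

```python
# Paraphrase of the scope-selection rule in _get_scopes; illustrative only.
def scopes_to_open(name, vs_name, is_training, index,
                   has_own_variables, initial_vs_reuse):
    if not name:
        return []
    scopes = []
    # reuse if this is a non-main training tower without its own variables,
    # or if the surrounding variable scope already had reuse=True
    reuse = (is_training and index > 0 and not has_own_variables) or initial_vs_reuse
    if vs_name:
        scopes.append(('variable_scope', vs_name, reuse))
    elif reuse:
        scopes.append(('variable_scope', '<current>', True))
    # always open a fresh name scope unless it coincides with the vs name
    if name != vs_name:
        scopes.append(('name_scope', name + '/'))
    return scopes

# e.g. a second training tower sharing variables with the main tower:
# [('variable_scope', '<current>', True), ('name_scope', 'tower1/')]
print(scopes_to_open('tower1', '', True, 1, False, False))
```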
@@ -269,13 +269,21 @@ class TowerTrainer(Trainer):
         input.setup(self.inputs_desc)
         SimplePredictBuilder(
-            ns_name=tower_name, vs_name='',
+            ns_name=tower_name, vs_name=self._main_tower_vs_name,
             device=device).build(input, self.tower_func)
         tower = self.tower_func.towers[tower_name]
         input_tensors = tower.get_tensors(input_names)
         output_tensors = tower.get_tensors(output_names)
         return OnlinePredictor(input_tensors, output_tensors)
+    @property
+    def _main_tower_vs_name(self):
+        """
+        The vs name for the "main" copy of the model,
+        to be used to build predictors.
+        """
+        return ""
 @six.add_metaclass(ABCMeta)
 class SingleCostTrainer(TowerTrainer):
...
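The new `_main_tower_vs_name` hook defaults to the root variable scope. A hedged sketch (not part of this commit) of how a subclass that gives every tower its own variable scope might override it so predictors read the first tower's variables:

```python
from tensorpack.train.tower import TowerTrainer  # module path as an assumption

class MyReplicatedTrainer(TowerTrainer):
    @property
    def _main_tower_vs_name(self):
        # predictors built by this trainer will look up variables under 'tower0/...'
        return "tower0"
```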