Commit 4f498c2a authored by Yuxin Wu

remove the 'towerp' constant completely. Use custom predict tower names everywhere.

parent d6db7efa
...@@ -146,7 +146,7 @@ def get_config(): ...@@ -146,7 +146,7 @@ def get_config():
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
InferenceRunner( InferenceRunner(
test_data, test_data,
[ScalarStats(['cost'], prefix='test')], prefix='test'), [ScalarStats(['cost'], prefix='test')], tower_name='InferenceTowerTest'),
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
CallbackFactory( CallbackFactory(
trigger_epoch=lambda self: trigger_epoch=lambda self:
......
...@@ -58,12 +58,12 @@ def summary_inferencer(trainer, infs): ...@@ -58,12 +58,12 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Callback): class InferenceRunnerBase(Callback):
""" Base methods for inference runner""" """ Base methods for inference runner"""
def __init__(self, input, infs, prefix='', extra_hooks=None): def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None, prefix=None):
""" """
Args: Args:
input (InputSource): the input to use. Must have ``size()``. input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run. infs (list[Inferencer]): list of :class:`Inferencer` to run.
prefix(str): an prefix used to build the tower. Must be set tower_name(str): name scope to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used. differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation. extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
""" """
...@@ -79,7 +79,9 @@ class InferenceRunnerBase(Callback): ...@@ -79,7 +79,9 @@ class InferenceRunnerBase(Callback):
self._size = input.size() self._size = input.size()
except NotImplementedError: except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!") raise ValueError("Input used in InferenceRunner must have a size!")
self._prefix = prefix self._tower_name = tower_name
if prefix is not None:
self._tower_name = 'InferenceTower' + prefix
if extra_hooks is None: if extra_hooks is None:
extra_hooks = [] extra_hooks = []
...@@ -90,14 +92,9 @@ class InferenceRunnerBase(Callback): ...@@ -90,14 +92,9 @@ class InferenceRunnerBase(Callback):
tower_id = self.trainer.config.predict_tower[0] tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0' device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
# TODO this cannot be InferenceRunner? fix it. check name
tower_name = 'InferenceRunnerTower'
if self._prefix:
tower_name += '_' + self._prefix
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source) self._tower_handle = self.trainer.predictor_factory.build(self._tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks() cbs = self._input_source.get_callbacks()
...@@ -127,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -127,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`. A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
""" """
def __init__(self, input, infs, prefix='', extra_hooks=None): def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
""" """
Args: Args:
input (InputSource or DataFlow): The :class:`InputSource` to run input (InputSource or DataFlow): The :class:`InputSource` to run
...@@ -140,7 +137,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -140,7 +137,7 @@ class InferenceRunner(InferenceRunnerBase):
if isinstance(input, FeedfreeInput): # TODO support other input if isinstance(input, FeedfreeInput): # TODO support other input
assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!" assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks) input, infs, tower_name=tower_name, extra_hooks=extra_hooks)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_output_tensors()
...@@ -153,7 +150,6 @@ def FeedfreeInferenceRunner(*args, **kwargs): ...@@ -153,7 +150,6 @@ def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs) return InferenceRunner(*args, **kwargs)
# TODO some scripts to test
class DataParallelInferenceRunner(InferenceRunnerBase): class DataParallelInferenceRunner(InferenceRunnerBase):
""" """
Inference by feeding datapoints in a data-parallel way to multiple GPUs. Inference by feeding datapoints in a data-parallel way to multiple GPUs.
...@@ -166,7 +162,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -166,7 +162,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input (DataParallelFeedInput or DataFlow) input (DataParallelFeedInput or DataFlow)
gpus (list[int]): list of GPU id gpus (list[int]): list of GPU id
""" """
self._tower_names = ['InferenceRunner{}'.format(k) for k in range(len(gpus))] self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
input = DataParallelFeedInput(input, self._tower_names) input = DataParallelFeedInput(input, self._tower_names)
assert isinstance(input, DataParallelFeedInput), input assert isinstance(input, DataParallelFeedInput), input
......
...@@ -117,7 +117,7 @@ class Monitors(TrainingMonitor): ...@@ -117,7 +117,7 @@ class Monitors(TrainingMonitor):
for val in summary.value: for val in summary.value:
if val.WhichOneof('value') == 'simple_value': if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150 suffix = '-summary' # tensorflow#6150, tensorboard#59
if val.tag.endswith(suffix): if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)] val.tag = val.tag[:-len(suffix)]
self._dispatch(lambda m: m.put_scalar(val.tag, val.simple_value)) self._dispatch(lambda m: m.put_scalar(val.tag, val.simple_value))
......
...@@ -21,7 +21,7 @@ def class_scope(func): ...@@ -21,7 +21,7 @@ def class_scope(func):
def get_name_scope_name(): def get_name_scope_name():
if get_tf_version_number() > 1.2: if get_tf_version_number() > 1.2:
return tf.get_name_scope().name return tf.get_default_graph().get_name_scope()
else: else:
g = tf.get_default_graph() g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG" s = "RANDOM_STR_ABCDEFG"
......
...@@ -15,14 +15,14 @@ else: ...@@ -15,14 +15,14 @@ else:
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope'] __all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
@deprecated("Use tf.get_name_scope() (available since 1.2.1).") @deprecated("Use tf.get_default_graph().get_name_scope() (available since 1.2.1).")
def get_name_scope_name(): def get_name_scope_name():
""" """
Returns: Returns:
str: the name of the current name scope, without the ending '/'. str: the name of the current name scope, without the ending '/'.
""" """
if get_tf_version_number() > 1.2: if get_tf_version_number() > 1.2:
return tf.get_name_scope().name return tf.get_default_graph().get_name_scope()
else: else:
g = tf.get_default_graph() g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG" s = "RANDOM_STR_ABCDEFG"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
from .common import get_tf_version_number
__all__ = ['get_current_tower_context', 'TowerContext'] __all__ = ['get_current_tower_context', 'TowerContext']
...@@ -16,8 +17,7 @@ class TowerContext(object): ...@@ -16,8 +17,7 @@ class TowerContext(object):
def __init__(self, tower_name, is_training=None, index=0, vs_name=''): def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
""" """
Args: Args:
tower_name (str): The name scope of the tower. Currently used tower_name (str): The name scope of the tower.
values are like: 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name. is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower. index (int): index of this tower.
vs_name (str): Open a variable scope with this name, if given. vs_name (str): Open a variable scope with this name, if given.
...@@ -106,6 +106,12 @@ class TowerContext(object): ...@@ -106,6 +106,12 @@ class TowerContext(object):
for c in self._ctxs: for c in self._ctxs:
c.__enter__() c.__enter__()
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
+ " You may need a different name for the tower!"
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext global _CurrentTowerContext
_CurrentTowerContext = None _CurrentTowerContext = None
......
...@@ -9,7 +9,6 @@ import tensorflow as tf ...@@ -9,7 +9,6 @@ import tensorflow as tf
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
from ..utils import logger from ..utils import logger
from ..utils.naming import PREDICT_TOWER
from .common import get_op_tensor_name from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars', __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
...@@ -29,11 +28,6 @@ def get_savename_from_varname( ...@@ -29,11 +28,6 @@ def get_savename_from_varname(
str: the name used to save the variable str: the name used to save the variable
""" """
name = varname name = varname
# TODO PREDICT_TOWER is not used anymore
if PREDICT_TOWER in name:
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph
return None
if varname_prefix is not None \ if varname_prefix is not None \
and name.startswith(varname_prefix): and name.startswith(varname_prefix):
name = name[len(varname_prefix) + 1:] name = name[len(varname_prefix) + 1:]
......
...@@ -10,9 +10,6 @@ GLOBAL_STEP_INCR_VAR_NAME = 'global_step_incr:0' ...@@ -10,9 +10,6 @@ GLOBAL_STEP_INCR_VAR_NAME = 'global_step_incr:0'
LOCAL_STEP_OP_NAME = 'local_step' LOCAL_STEP_OP_NAME = 'local_step'
LOCAL_STEP_VAR_NAME = 'local_step:0' LOCAL_STEP_VAR_NAME = 'local_step:0'
# prefix of predict tower
PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS' MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment