Commit 4f498c2a authored by Yuxin Wu

remove the 'towerp' constant completely. Use custom predict tower names everywhere.

parent d6db7efa
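
The user-facing change is the keyword argument on InferenceRunner: the old prefix is gone and the tower's name scope is passed explicitly. A minimal before/after sketch, using the call from the first hunk below (the surrounding config is assumed):

    # before: tower name derived from an optional prefix
    InferenceRunner(test_data, [ScalarStats(['cost'], prefix='test')], prefix='test')
    # after: the tower's name scope is given explicitly
    InferenceRunner(test_data, [ScalarStats(['cost'], prefix='test')], tower_name='InferenceTowerTest')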
......@@ -146,7 +146,7 @@ def get_config():
RunOp(lambda: M.reset_lstm_state()),
InferenceRunner(
test_data,
[ScalarStats(['cost'], prefix='test')], prefix='test'),
[ScalarStats(['cost'], prefix='test')], tower_name='InferenceTowerTest'),
RunOp(lambda: M.reset_lstm_state()),
CallbackFactory(
trigger_epoch=lambda self:
......
......@@ -58,12 +58,12 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Callback):
""" Base methods for inference runner"""
def __init__(self, input, infs, prefix='', extra_hooks=None):
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None, prefix=None):
"""
Args:
input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
prefix(str): an prefix used to build the tower. Must be set
tower_name (str): the name scope in which to build the tower. Must be set
differently if more than one :class:`InferenceRunner` is used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
"""
......@@ -79,7 +79,9 @@ class InferenceRunnerBase(Callback):
self._size = input.size()
except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!")
self._prefix = prefix
self._tower_name = tower_name
if prefix is not None:
self._tower_name = 'InferenceTower' + prefix
if extra_hooks is None:
extra_hooks = []
......@@ -90,14 +92,9 @@ class InferenceRunnerBase(Callback):
tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
# TODO this cannot be InferenceRunner? fix it. check name
tower_name = 'InferenceRunnerTower'
if self._prefix:
tower_name += '_' + self._prefix
self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source)
self._tower_handle = self.trainer.predictor_factory.build(self._tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
......@@ -127,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, prefix='', extra_hooks=None):
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
"""
Args:
input (InputSource or DataFlow): The :class:`InputSource` to run
......@@ -140,7 +137,7 @@ class InferenceRunner(InferenceRunnerBase):
if isinstance(input, FeedfreeInput): # TODO support other input
assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
super(InferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks)
input, infs, tower_name=tower_name, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_output_tensors()
......@@ -153,7 +150,6 @@ def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs)
# TODO some scripts to test
class DataParallelInferenceRunner(InferenceRunnerBase):
"""
Inference by feeding datapoints in a data-parallel way to multiple GPUs.
......@@ -166,7 +162,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input (DataParallelFeedInput or DataFlow)
gpus (list[int]): list of GPU ids
"""
self._tower_names = ['InferenceRunner{}'.format(k) for k in range(len(gpus))]
self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow):
input = DataParallelFeedInput(input, self._tower_names)
assert isinstance(input, DataParallelFeedInput), input
......
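
To illustrate the "must be set differently" requirement in the docstring above, a rough sketch of two runners coexisting in one callback list (the dataflow names and stats are assumptions):

    callbacks = [
        InferenceRunner(val_data, [ScalarStats('cost')], tower_name='InferenceTower'),
        InferenceRunner(test_data, [ScalarStats('cost')], tower_name='InferenceTowerTest'),
        # the data-parallel variant names its towers 'InferenceTower0', 'InferenceTower1', ... per GPU
        # DataParallelInferenceRunner(val_data, [ScalarStats('cost')], gpus=[0, 1]),
    ]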
......@@ -117,7 +117,7 @@ class Monitors(TrainingMonitor):
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150
suffix = '-summary' # tensorflow#6150, tensorboard#59
if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)]
self._dispatch(lambda m: m.put_scalar(val.tag, val.simple_value))
......
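
A small sketch of what the tag cleanup above produces for a typical summary tag (the tag itself is made up):

    import re
    tag = 'tower0/cross_entropy_loss-summary'
    tag = re.sub('tower[p0-9]+/', '', tag)    # drop the tower prefix -> 'cross_entropy_loss-summary'
    suffix = '-summary'                       # appended by tf.summary; tensorflow#6150, tensorboard#59
    if tag.endswith(suffix):
        tag = tag[:-len(suffix)]              # -> 'cross_entropy_loss'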
......@@ -21,7 +21,7 @@ def class_scope(func):
def get_name_scope_name():
if get_tf_version_number() > 1.2:
return tf.get_name_scope().name
return tf.get_default_graph().get_name_scope()
else:
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
......
......@@ -15,14 +15,14 @@ else:
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
@deprecated("Use tf.get_name_scope() (available since 1.2.1).")
@deprecated("Use tf.get_default_graph().get_name_scope() (available since 1.2.1).")
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
if get_tf_version_number() > 1.2:
return tf.get_name_scope().name
return tf.get_default_graph().get_name_scope()
else:
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
......
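
The replacement the deprecation message points to, sketched for TF >= 1.2.1 (the scope name is illustrative):

    import tensorflow as tf
    with tf.name_scope('InferenceTower'):
        scope = tf.get_default_graph().get_name_scope()   # 'InferenceTower', no trailing '/'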
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from .common import get_tf_version_number
__all__ = ['get_current_tower_context', 'TowerContext']
......@@ -16,8 +17,7 @@ class TowerContext(object):
def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
"""
Args:
tower_name (str): The name scope of the tower. Currently used
values are like: 'tower0', 'towerp0', or ''
tower_name (str): The name scope of the tower.
is_training (bool): if None, automatically determined from tower_name.
index (int): index of this tower.
vs_name (str): Open a variable scope with this name, if given.
......@@ -106,6 +106,12 @@ class TowerContext(object):
for c in self._ctxs:
c.__enter__()
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
+ " You may need a different name for the tower!"
def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext
_CurrentTowerContext = None
......
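
What the new assertion in TowerContext.__enter__ guards against, roughly: re-using a tower name whose scope already exists, so TensorFlow uniquifies it and the opened scope no longer matches. Names here are illustrative, assuming tensorflow imported as tf and TowerContext from the hunk above:

    with tf.name_scope('InferenceTower'):
        pass   # the scope 'InferenceTower' now exists in the graph
    # Entering a tower with the same name would give it the scope 'InferenceTower_1',
    # which differs from self._name and trips the "Name conflict" assertion above.
    with TowerContext('InferenceTower', is_training=False):
        pass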
......@@ -9,7 +9,6 @@ import tensorflow as tf
from collections import defaultdict
import numpy as np
from ..utils import logger
from ..utils.naming import PREDICT_TOWER
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
......@@ -29,11 +28,6 @@ def get_savename_from_varname(
str: the name used to save the variable
"""
name = varname
# TODO PREDICT_TOWER is not used anymore
if PREDICT_TOWER in name:
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph
return None
if varname_prefix is not None \
and name.startswith(varname_prefix):
name = name[len(varname_prefix) + 1:]
......
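
A hedged example of the prefix stripping that remains in get_savename_from_varname once the PREDICT_TOWER special case is dropped (the variable name is hypothetical):

    # 'tower0/conv0/W:0' with varname_prefix='tower0' is saved under the name 'conv0/W:0'
    savename = get_savename_from_varname('tower0/conv0/W:0', varname_prefix='tower0')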
......@@ -10,9 +10,6 @@ GLOBAL_STEP_INCR_VAR_NAME = 'global_step_incr:0'
LOCAL_STEP_OP_NAME = 'local_step'
LOCAL_STEP_VAR_NAME = 'local_step:0'
# prefix of predict tower
PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
......