Commit e5ff50e7 authored by Yuxin Wu

a better TowerTensorHandle to access tensors (currently for predictor only) (#318)

parent 9911b234
@@ -10,7 +10,8 @@ Feature Requests:
    It may not have to be added to tensorpack unless you have a good reason.
 3. Note that we don't implement papers at other's requests.

-Usage Questions:
-Usage questions are like "How do I do [this specific thing] in tensorpack?".
+Usage Questions, e.g.:
+"How do I do [this specific thing] in tensorpack?"
+"Why certain examples need to be written in this way?"
 We don't answer general machine learning questions like:
-"I want to do [this machine learning task]. What specific things I need to do?"
+"I want to do [this machine learning task]. What specific things do I need to do?"
@@ -4,8 +4,7 @@
 import tensorflow as tf
 from ..utils import logger
-from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import TowerContext, TowerFuncWrapper
 from ..tfutils.collection import freeze_collection
 from ..utils.naming import TOWER_FREEZE_KEYS
 from ..input_source import PlaceholderInput
@@ -13,31 +12,6 @@ from ..input_source import PlaceholderInput
 __all__ = []
-class PredictorTowerHandle(object):
-    def __init__(self, tower_name, input_desc_names, input_tensors=None):
-        self._tower_name = tower_name
-        self._input_desc_names = [get_op_tensor_name(k)[1] for k in input_desc_names]
-        if input_tensors is not None:
-            self._input_names = [get_op_tensor_name(k.name)[1] for k in input_tensors]
-        else:
-            self._input_names = self._input_desc_names
-
-    def get_tensors(self, names):
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[1]
-            if name in self._input_names:
-                return name
-            elif name in self._input_desc_names:
-                idx = self._input_desc_names.index(name)
-                return self._input_names[idx]
-            else:
-                # if the name is not a placeholder, use it's name in each tower
-                return self._tower_name + '/' + name
-
-        names = list(map(maybe_inside_tower, names))
-        tensors = get_tensors_by_names(names)
-        return tensors
 class PredictorFactory(object):
     """ Make predictors from :class:`ModelDesc`."""
@@ -68,18 +42,20 @@ class PredictorFactory(object):
                 freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
             # also freeze UPDATE_OPS in inference, because they should never be used
             # TODO a better way to log and warn about collection change during build_graph.
+            inputs_desc = self._model.get_inputs_desc()
             if input is None:
                 input = PlaceholderInput()
-                input.setup(self._model.get_inputs_desc())
+                input.setup(inputs_desc)
-            input = input.get_input_tensors()
-            assert isinstance(input, (list, tuple)), input
+            inputs = input.get_input_tensors()
+            assert isinstance(inputs, (list, tuple)), inputs
-            # TODO still using tensors here instead of inputsource
-            # can be fixed after having towertensorhandle inside modeldesc
-            self._model.build_graph(input)
-
-            desc_names = [k.name for k in self._model.get_inputs_desc()]
-            self._names_built[tower_name] = PredictorTowerHandle(
-                tower_name, desc_names, input)
+
+            def tower_func(*inputs):
+                self._model.build_graph(inputs)
+            tower_func = TowerFuncWrapper(tower_func, inputs_desc)
+            tower_func(*inputs)
+
+            self._names_built[tower_name] = tower_func.towers[0]
             return self._names_built[tower_name]

     def has_built(self, tower_name):
...
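The comment in this hunk explains why UPDATE_OPS is frozen while the prediction tower is built. For illustration only (not part of this commit), here is a minimal sketch of that idea, assuming the usual context-manager behaviour of tensorpack.tfutils.collection.freeze_collection, i.e. the listed collections are restored to their previous contents when the block exits:

import tensorflow as tf
from tensorpack.tfutils.collection import freeze_collection

with freeze_collection([tf.GraphKeys.UPDATE_OPS]):
    # A BatchNorm layer built here for inference would normally append its
    # moving-average update ops to UPDATE_OPS; freezing the collection means
    # they are dropped on exit and cannot leak into the training graph.
    pass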
@@ -117,6 +117,9 @@ def get_op_or_tensor_by_name(name):
     Args:
         name (list[str] or str): names of operations or tensors.
+
+    Raises:
+        KeyError, if the name doesn't exist
     """
     G = tf.get_default_graph()
...
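For illustration only (not part of this commit), a minimal sketch of the documented behaviour; the placeholder name is hypothetical and the import path tensorpack.tfutils.common is assumed:

import tensorflow as tf
from tensorpack.tfutils.common import get_op_or_tensor_by_name

x = tf.placeholder(tf.float32, [None, 3], name='image')
t = get_op_or_tensor_by_name('image:0')   # returns the placeholder tensor
try:
    get_op_or_tensor_by_name('no_such_name:0')
except KeyError:
    pass  # raised when the name does not exist in the default graph, as documented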
@@ -4,9 +4,12 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
+from six.moves import zip

-from .common import get_tf_version_number
+from ..utils import logger
+from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name

-__all__ = ['get_current_tower_context', 'TowerContext']
+__all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper']

 _CurrentTowerContext = None
@@ -54,6 +57,9 @@ class TowerContext(object):
         """
         return self.is_main_training_tower or len(self._vs_name) > 0

+    # TODO clarify the interface on name/vs_name/ns_name.
+    # TODO in inference, vs_name may need to be different from ns_name.
+    # How to deal with this?
     @property
     def name(self):
         return self._name
@@ -62,6 +68,10 @@
     def vs_name(self):
         return self._vs_name

+    @property
+    def ns_name(self):
+        return self._name
+
     def filter_vars_by_vs_name(self, varlist):
         """
         Filter the list and only keep those under the current variable scope.
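For illustration only (not part of this commit), a small sketch of the new ns_name property, assuming TowerContext accepts a tower name and an is_training flag as in this version of the code:

from tensorpack.tfutils.tower import TowerContext, get_current_tower_context

with TowerContext('towerp', is_training=False):
    ctx = get_current_tower_context()
    # ns_name is the tower's name scope; ops created here are named 'towerp/...'
    assert ctx.name == 'towerp' and ctx.ns_name == 'towerp'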
@@ -85,13 +95,12 @@ class TowerContext(object):
     def __enter__(self):
         global _CurrentTowerContext
-        assert _CurrentTowerContext is None, \
-            "Nesting TowerContext!"
+        assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
         self._ctxs = []
         curr_vs = tf.get_variable_scope()
-        assert curr_vs.name == '', "Nesting TowerContext with an existing variable scope!"
-        # assert empty name scope as well (>1.2.1?)
+        assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
         if len(self._name):
             if not self.is_training:
                 # if not training, should handle reuse outside
@@ -114,6 +123,7 @@
             c.__enter__()

         if get_tf_version_number() >= 1.2:
+            # check that ns_name is always the same as _name
             ns = tf.get_default_graph().get_name_scope()
             assert ns == self._name, \
                 "Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
@@ -135,3 +145,126 @@
 def get_current_tower_context():
     global _CurrentTowerContext
     return _CurrentTowerContext
+class TowerFuncWrapper(object):
+    """
+    A wrapper around a function which builds one tower (one replicate of the model).
+    It keeps track of the name scope, variable scope and input/output tensors
+    each time the function is called.
+    """
+
+    def __init__(self, tower_fn, inputs_desc=None):
+        """
+        Args:
+            tower_fn: a function which builds one tower in the graph.
+                It takes several input tensors and could return anything.
+            inputs_desc ([InputDesc]): use this to figure out the right name for the input tensors.
+        """
+        self._tower_fn = tower_fn
+        self._inputs_desc = inputs_desc
+
+        self._towers = []
+
+    def __call__(self, *args):
+        ctx = get_current_tower_context()
+        assert ctx is not None, "Function must be called under TowerContext!"
+        output = self._tower_fn(*args)
+        handle = TowerTensorHandle(ctx, args, output, self._inputs_desc)
+        self._towers.append(handle)
+        return output
+
+    @property
+    def towers(self):
+        # TODO another wrapper around towerhandlelist
+        return self._towers
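For illustration only (not part of this commit), a hedged usage sketch of TowerFuncWrapper; the toy model, the tensor names and the top-level import location of InputDesc are assumptions:

import tensorflow as tf
from tensorpack import InputDesc                       # import location assumed
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper

def my_tower(image):
    # build one (toy) replicate of the model
    return tf.identity(image * 2.0, name='doubled')

inputs_desc = [InputDesc(tf.float32, [None, 3], 'image')]
tower_func = TowerFuncWrapper(my_tower, inputs_desc)

with TowerContext('', is_training=False):              # must be called under a TowerContext
    image = tf.placeholder(tf.float32, [None, 3], name='image')
    tower_func(image)

handle = tower_func.towers[0]                          # a TowerTensorHandle for this call
doubled = handle.get_tensor('doubled')                 # look up a tensor by its in-tower name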
+class TowerTensorHandle(object):
+    """
+    When a function is called multiple times under different towers,
+    it becomes hard to keep track of the scope and to access the tensors
+    created in each tower.
+    This class provides easy access to the tensors as well as the
+    inputs/outputs created in each tower.
+    """
+
+    # TODO hide it from doc
+    def __init__(self, ctx, input, output, inputs_desc=None):
+        """
+        Don't create this handle yourself; it is created by :class:`TowerFuncWrapper`.
+        """
+        self._ctx = ctx
+
+        self._extra_tensor_names = {}
+        if inputs_desc is not None:
+            assert len(inputs_desc) == len(input)
+            self._extra_tensor_names = {
+                get_op_tensor_name(x.name)[1]: y for x, y in zip(inputs_desc, input)}
+        self._input = input
+        self._output = output
+
+    @property
+    def vs_name(self):
+        return self._ctx.vs_name
+
+    @property
+    def ns_name(self):
+        return self._ctx.ns_name
+
+    def get_tensor(self, name):
+        """
+        Get a tensor in this tower. The name can be:
+        1. The name of the tensor without any tower prefix.
+        2. The name of an :class:`InputDesc`, if it is used when building the tower.
+        """
+        name = get_op_tensor_name(name)[1]
+        if len(self.ns_name):
+            name_with_ns = self.ns_name + "/" + name
+        else:
+            name_with_ns = name
+
+        try:
+            ret = get_op_or_tensor_by_name(name_with_ns)
+        except KeyError:
+            if name in self._extra_tensor_names:
+                return self._extra_tensor_names[name]
+            raise
+        else:
+            if name in self._extra_tensor_names:
+                logger.warn(
+                    "'{}' may refer to either the tensor '{}' or the input '{}'.".format(
+                        name, ret.name, self._extra_tensor_names[name].name) +
+                    " Assuming it is the tensor '{}'.".format(ret.name))
+            return ret
+
+    def get_tensors(self, names):
+        return [self.get_tensor(name) for name in names]
+
+    def __getitem__(self, name):
+        return self.get_tensor(name)
+
+    def get_variable(self, name):
+        """
+        Get a variable used in this tower.
+        """
+        name = get_op_tensor_name(name)[1]
+        if len(self.vs_name):
+            name_with_vs = self.vs_name + "/" + name
+        else:
+            name_with_vs = name
+        return get_op_or_tensor_by_name(name_with_vs)
+
+    @property
+    def input(self):
+        """
+        The list of input tensors used to build the tower.
+        """
+        return self._input
+
+    @property
+    def output(self):
+        """
+        The output returned by the tower function.
+        """
+        return self._output
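Continuing the hypothetical sketch above (for illustration only), a second call under a named tower shows both resolution paths of get_tensor(): a plain tensor name gets the tower prefix, while an InputDesc name falls back to the tensor that was actually passed in:

with TowerContext('towerp', is_training=False):
    image2 = tf.placeholder(tf.float32, [None, 3], name='image2')
    tower_func(image2)

handle2 = tower_func.towers[1]      # the handle recorded for the second call
handle2.get_tensor('doubled')       # resolved to 'towerp/doubled:0'
handle2['image']                    # InputDesc name -> the tensor passed in (image2)
handle2.input, handle2.output       # raw inputs given to / output returned by the tower function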