Commit e5ff50e7 authored by Yuxin Wu

a better TowerTensorHandle to access tensors (currently for predictor only) (#318)

parent 9911b234
......@@ -10,7 +10,8 @@ Feature Requests:
It may not have to be added to tensorpack unless you have a good reason.
3. Note that we don't implement papers at others' requests.
Usage Questions:
Usage questions are like "How do I do [this specific thing] in tensorpack?".
Usage Questions, e.g.:
"How do I do [this specific thing] in tensorpack?"
"Why do certain examples need to be written in this way?"
We don't answer general machine learning questions like:
"I want to do [this machine learning task]. What specific things I need to do?"
"I want to do [this machine learning task]. What specific things do I need to do?"
......@@ -4,8 +4,7 @@
import tensorflow as tf
from ..utils import logger
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..tfutils.tower import TowerContext, TowerFuncWrapper
from ..tfutils.collection import freeze_collection
from ..utils.naming import TOWER_FREEZE_KEYS
from ..input_source import PlaceholderInput
......@@ -13,31 +12,6 @@ from ..input_source import PlaceholderInput
__all__ = []
class PredictorTowerHandle(object):
    """
    Look up tensors that belong to one predictor tower.

    Names that match an input (either the declared input name or the name
    of the tensor actually fed to the tower) are used as they are. Any other
    name is looked up under the tower's name prefix.
    """

    def __init__(self, tower_name, input_desc_names, input_tensors=None):
        """
        Args:
            tower_name (str): prefix prepended to names which are not inputs.
            input_desc_names (list[str]): declared names of the inputs.
            input_tensors (list or None): the tensors used as inputs. Each must
                have a ``.name``. When None, the declared names are used as
                the input names.
        """
        self._tower_name = tower_name
        # keep only the second element returned by get_op_tensor_name
        # (presumably the tensor-style name; confirm against its definition)
        self._input_desc_names = [get_op_tensor_name(n)[1] for n in input_desc_names]
        if input_tensors is None:
            self._input_names = self._input_desc_names
        else:
            self._input_names = [get_op_tensor_name(t.name)[1] for t in input_tensors]

    def _resolve_name(self, name):
        """Map a user-given name to the name that is looked up in the graph."""
        tensor_name = get_op_tensor_name(name)[1]
        if tensor_name in self._input_names:
            return tensor_name
        if tensor_name in self._input_desc_names:
            # a declared input name: translate it to the actual input at the same position
            position = self._input_desc_names.index(tensor_name)
            return self._input_names[position]
        # if the name is not a placeholder, use its name in each tower
        return self._tower_name + '/' + tensor_name

    def get_tensors(self, names):
        """
        Args:
            names (list[str]): names of inputs or of tensors inside the tower.

        Returns:
            whatever ``get_tensors_by_names`` returns for the resolved names.
        """
        resolved = [self._resolve_name(n) for n in names]
        return get_tensors_by_names(resolved)
class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc`."""
......@@ -68,18 +42,20 @@ class PredictorFactory(object):
freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
# also freeze UPDATE_OPS in inference, because they should never be used
# TODO a better way to log and warn about collection change during build_graph.
inputs_desc = self._model.get_inputs_desc()
if input is None:
input = PlaceholderInput()
input.setup(self._model.get_inputs_desc())
input = input.get_input_tensors()
assert isinstance(input, (list, tuple)), input
# TODO still using tensors here instead of inputsource
# can be fixed after having towertensorhandle inside modeldesc
self._model.build_graph(input)
input.setup(inputs_desc)
inputs = input.get_input_tensors()
assert isinstance(inputs, (list, tuple)), inputs
def tower_func(*inputs):
self._model.build_graph(inputs)
tower_func = TowerFuncWrapper(tower_func, inputs_desc)
tower_func(*inputs)
desc_names = [k.name for k in self._model.get_inputs_desc()]
self._names_built[tower_name] = PredictorTowerHandle(
tower_name, desc_names, input)
self._names_built[tower_name] = tower_func.towers[0]
return self._names_built[tower_name]
def has_built(self, tower_name):
......
......@@ -117,6 +117,9 @@ def get_op_or_tensor_by_name(name):
Args:
name (list[str] or str): names of operations or tensors.
Raises:
KeyError, if the name doesn't exist
"""
G = tf.get_default_graph()
......
......@@ -4,9 +4,12 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from .common import get_tf_version_number
from six.moves import zip
__all__ = ['get_current_tower_context', 'TowerContext']
from ..utils import logger
from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper']
_CurrentTowerContext = None
......@@ -54,6 +57,9 @@ class TowerContext(object):
"""
return self.is_main_training_tower or len(self._vs_name) > 0
# TODO clarify the interface on name/vs_name/ns_name.
# TODO in inference, vs_name may need to be different from ns_name.
# How to deal with this?
@property
def name(self):
return self._name
......@@ -62,6 +68,10 @@ class TowerContext(object):
def vs_name(self):
return self._vs_name
@property
def ns_name(self):
return self._name
def filter_vars_by_vs_name(self, varlist):
"""
Filter the list and only keep those under the current variable scope.
......@@ -85,13 +95,12 @@ class TowerContext(object):
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, \
"Nesting TowerContext!"
assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
_CurrentTowerContext = self
self._ctxs = []
curr_vs = tf.get_variable_scope()
assert curr_vs.name == '', "Nesting TowerContext with an existing variable scope!"
# assert empty name scope as well (>1.2.1?)
assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"
if len(self._name):
if not self.is_training:
# if not training, should handle reuse outside
......@@ -114,6 +123,7 @@ class TowerContext(object):
c.__enter__()
if get_tf_version_number() >= 1.2:
# check that ns_name is always the same as _name
ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
......@@ -135,3 +145,126 @@ class TowerContext(object):
def get_current_tower_context():
    """
    Returns:
        the value of the module-level ``_CurrentTowerContext``: the
        :class:`TowerContext` recorded when a context was entered,
        or None (its initial value).
    """
    # Read-only access to the module-level variable; no ``global`` needed.
    return _CurrentTowerContext
class TowerFuncWrapper(object):
    """
    Wrap a function which builds one tower (one replicate of the model)
    and remember every call made to it.

    Each call records the current tower context, the call arguments and the
    returned value in a :class:`TowerTensorHandle`, available via :attr:`towers`.
    """

    def __init__(self, tower_fn, inputs_desc=None):
        """
        Args:
            tower_fn: a function which builds one tower in the graph.
                It is called with the input tensors and may return anything.
            inputs_desc ([InputDesc]): passed to every handle, so that input
                tensors can be found by the name of their description.
        """
        self._towers = []
        self._inputs_desc = inputs_desc
        self._tower_fn = tower_fn

    def __call__(self, *args):
        """
        Build one tower by calling the wrapped function with ``args``.

        Must be called while a tower context is active.

        Returns:
            whatever the wrapped function returns.
        """
        tower_ctx = get_current_tower_context()
        assert tower_ctx is not None, "Function must be called under TowerContext!"
        result = self._tower_fn(*args)
        self._towers.append(
            TowerTensorHandle(tower_ctx, args, result, self._inputs_desc))
        return result

    @property
    def towers(self):
        """
        The list of handles, one per call, in call order.
        """
        # TODO another wrapper around towerhandlelist
        return self._towers
class TowerTensorHandle(object):
    """
    When a function is called multiple times under each tower,
    it becomes hard to keep track of the scope and access those tensors
    in each tower.
    This class provides easy access to the tensors as well as the
    inputs/outputs created in each tower.
    """
    # TODO hide it from doc

    def __init__(self, ctx, input, output, inputs_desc=None):
        """
        Don't use it because you never need to create the handle by yourself.

        Args:
            ctx: the tower context; its ``vs_name`` and ``ns_name`` are read.
            input: the inputs the tower function was called with.
            output: whatever the tower function returned.
            inputs_desc: optional descriptions with a ``.name``, one per input,
                in the same order as ``input``.
        """
        self._ctx = ctx
        # maps the name of each input description to the input actually used
        self._extra_tensor_names = {}
        if inputs_desc is not None:
            assert len(inputs_desc) == len(input)
            self._extra_tensor_names = {
                get_op_tensor_name(x.name)[1]: y for x, y in zip(inputs_desc, input)}
        self._input = input
        self._output = output

    @staticmethod
    def _join_scope(scope, name):
        """Prefix ``name`` with ``scope`` and a slash, unless ``scope`` is empty."""
        if scope:
            return scope + "/" + name
        return name

    @property
    def vs_name(self):
        """The ``vs_name`` of the tower context this handle was built with."""
        return self._ctx.vs_name

    @property
    def ns_name(self):
        """The ``ns_name`` of the tower context this handle was built with."""
        return self._ctx.ns_name

    def get_tensor(self, name):
        """
        Get a tensor in this tower. The name can be:

        1. The name of the tensor without any tower prefix.
        2. The name of an :class:`InputDesc`, if it is used when building the tower.

        If the name matches both, the tensor found under the tower's
        name scope is returned and a warning is logged.

        Raises:
            KeyError, if the name matches neither.
        """
        name = get_op_tensor_name(name)[1]
        name_with_ns = self._join_scope(self.ns_name, name)

        try:
            ret = get_op_or_tensor_by_name(name_with_ns)
        except KeyError:
            # not found inside the tower: fall back to the inputs
            if name in self._extra_tensor_names:
                return self._extra_tensor_names[name]
            raise

        if name in self._extra_tensor_names:
            # Single literal with a space between the two sentences
            # (they used to be concatenated without one).
            logger.warn(
                "'{}' may refer to both the tensor '{}' or the input '{}'. "
                "Assuming it is the tensor '{}'.".format(
                    name, ret.name, self._extra_tensor_names[name].name, ret.name))
        return ret

    def get_tensors(self, names):
        """
        Call :meth:`get_tensor` on each name and return the results as a list.
        """
        return [self.get_tensor(name) for name in names]

    def __getitem__(self, name):
        """Same as :meth:`get_tensor`."""
        return self.get_tensor(name)

    def get_variable(self, name):
        """
        Get a variable used in this tower.

        The name is looked up under the tower's ``vs_name``
        (or as it is, when ``vs_name`` is empty).
        """
        name = get_op_tensor_name(name)[1]
        name_with_vs = self._join_scope(self.vs_name, name)
        return get_op_or_tensor_by_name(name_with_vs)

    @property
    def input(self):
        """
        The list of input tensors used to build the tower.
        """
        return self._input

    @property
    def output(self):
        """
        The output returned by the tower function.
        """
        return self._output
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment