Commit ddebb23c authored by Yuxin Wu's avatar Yuxin Wu

move input_names mapping to InputSource

parent 9626ebd8
......@@ -16,8 +16,7 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_source import TensorInput, FeedInput, DataParallelFeedInput
from ..train.utils import get_tensors_inputs
from ..train.input_source import FeedInput, DataParallelFeedInput, FeedfreeInput
from ..predict import PredictorTowerBuilder
from .base import Callback
......@@ -89,8 +88,7 @@ class InferenceRunnerBase(Callback):
self._predict_tower_id = self.trainer.config.predict_tower[0]
def fn(_):
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, (list, tuple)), in_tensors
in_tensors = self._input_source.get_input_tensors()
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
......@@ -105,9 +103,6 @@ class InferenceRunnerBase(Callback):
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
def _find_input_tensors(self):
pass
@abstractmethod
def _build_hook(self, inf):
pass
......@@ -143,9 +138,6 @@ class InferenceRunner(InferenceRunnerBase):
super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks)
def _find_input_tensors(self):
return self._input_source.get_input_tensors()
def _build_hook(self, inf):
out_names = inf.get_output_tensors()
fetches = self._get_tensors_maybe_in_tower(out_names)
......@@ -154,34 +146,22 @@ class InferenceRunner(InferenceRunnerBase):
class FeedfreeInferenceRunner(InferenceRunnerBase):
""" A callback that runs a list of :class:`Inferencer` on some
:class:`TensorInput`, such as some tensor from a TensorFlow data reading
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
"""
def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None):
def __init__(self, input, infs, prefix='', extra_hooks=None):
"""
Args:
input (TensorInput): the input to use. Must have ``size()``.
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
            prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
assert isinstance(input, TensorInput), input
assert isinstance(input, FeedfreeInput), input
super(FeedfreeInferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks)
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
def _find_input_tensors(self):
# TODO move mapping to InputSource
tensors = self._input_source.get_input_tensors()
placeholders = self.trainer.model.get_reused_placehdrs()
if self.input_names is None:
return tensors
else:
return get_tensors_inputs(placeholders, tensors, self.input_names)
def _build_hook(self, inf):
out_names = inf.get_output_tensors() # all is tensorname
......
......@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
import six
from six.moves import range, zip
from .utils import get_placeholders_by_names
from .utils import get_placeholders_by_names, get_tensors_inputs
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
......@@ -25,11 +25,12 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput',
__all__ = ['InputSource',
'FeedInput', 'DataParallelFeedInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
'ZMQInput', 'DummyConstantInput', 'TensorInput',
'StagingInputWrapper', 'ReorderInputSource']
@six.add_metaclass(ABCMeta)
......@@ -73,6 +74,8 @@ class FeedInput(InputSource):
input_names (list[str]): input names this DataFlow maps to
"""
assert isinstance(ds, DataFlow), ds
if input_names is not None:
assert isinstance(input_names, (list, tuple)), input_names
self.ds = ds
self._input_names = input_names
......@@ -213,7 +216,9 @@ class QueueInput(FeedfreeInput):
"""
Args:
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): Defaults to a FIFO queue of size 50.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
......@@ -227,7 +232,7 @@ class QueueInput(FeedfreeInput):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"QueueInput has to be used with input placeholders!"
"QueueInput has to be used with some InputDesc!"
if self.queue is None:
self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs],
......@@ -259,7 +264,9 @@ class BatchQueueInput(FeedfreeInput):
Args:
ds(DataFlow): the input DataFlow.
batch_size(int): the batch size.
queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 3000.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
......@@ -273,7 +280,7 @@ class BatchQueueInput(FeedfreeInput):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with input placeholders!"
"BatchQueueInput has to be used with some InputDesc!"
# prepare placeholders without the first dimension
placehdrs_nobatch = []
......@@ -366,6 +373,8 @@ class TensorInput(FeedfreeInput):
get_tensor_fn: a function which returns a list of input tensors
when called. It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined.
            input_names (list[str]): input names the tensors map to. Defaults
to be all the inputs of the model.
"""
self.get_tensor_fn = get_tensor_fn
if size is not None:
......@@ -491,3 +500,40 @@ class StagingInputWrapper(FeedfreeInput):
def get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs)
class ReorderInputSource(FeedfreeInput):
    """
    Wrapper for an :class:`InputSource` that produces only a subset of the
    model's InputDesc. It remaps the wrapped input's tensors onto the full
    list of model placeholders by name.
    """
    def __init__(self, input, names):
        """
        Args:
            input(TensorInput): the underlying input whose tensors will be
                remapped.
            names(list[str]): names of the model inputs that the tensors
                produced by ``input`` correspond to, in order.
        """
        assert isinstance(input, TensorInput), input
        self._input = input
        assert isinstance(names, (list, tuple)), names
        self._names = names

    def size(self):
        # Delegate: size is defined entirely by the wrapped input.
        return self._input.size()

    def setup(self, model):
        # Cache the model's full placeholder list for remapping later.
        self._all_placehdrs = model.get_reused_placehdrs()
        self._input.setup(model)

    def setup_training(self, trainer):
        # Same as setup(), but the model is reached through the trainer.
        self._all_placehdrs = trainer.model.get_reused_placehdrs()
        self._input.setup_training(trainer)

    def reset_state(self):
        self._input.reset_state()

    def get_input_tensors(self):
        # Remap the wrapped input's tensors to the positions of the named
        # placeholders within the full placeholder list.
        produced = self._input.get_input_tensors()
        return get_tensors_inputs(self._all_placehdrs, produced, self._names)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment