Commit ddebb23c authored by Yuxin Wu's avatar Yuxin Wu

move input_names mapping to InputSource

parent 9626ebd8
...@@ -16,8 +16,7 @@ from ..utils import logger, get_tqdm_kwargs ...@@ -16,8 +16,7 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..train.input_source import TensorInput, FeedInput, DataParallelFeedInput from ..train.input_source import FeedInput, DataParallelFeedInput, FeedfreeInput
from ..train.utils import get_tensors_inputs
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Callback from .base import Callback
...@@ -89,8 +88,7 @@ class InferenceRunnerBase(Callback): ...@@ -89,8 +88,7 @@ class InferenceRunnerBase(Callback):
self._predict_tower_id = self.trainer.config.predict_tower[0] self._predict_tower_id = self.trainer.config.predict_tower[0]
def fn(_): def fn(_):
in_tensors = self._find_input_tensors() in_tensors = self._input_source.get_input_tensors()
assert isinstance(in_tensors, (list, tuple)), in_tensors
self.trainer.model.build_graph(in_tensors) self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
...@@ -105,9 +103,6 @@ class InferenceRunnerBase(Callback): ...@@ -105,9 +103,6 @@ class InferenceRunnerBase(Callback):
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix) return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
def _find_input_tensors(self):
pass
@abstractmethod @abstractmethod
def _build_hook(self, inf): def _build_hook(self, inf):
pass pass
...@@ -143,9 +138,6 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -143,9 +138,6 @@ class InferenceRunner(InferenceRunnerBase):
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks) input, infs, prefix='', extra_hooks=extra_hooks)
def _find_input_tensors(self):
return self._input_source.get_input_tensors()
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_output_tensors()
fetches = self._get_tensors_maybe_in_tower(out_names) fetches = self._get_tensors_maybe_in_tower(out_names)
...@@ -154,34 +146,22 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -154,34 +146,22 @@ class InferenceRunner(InferenceRunnerBase):
class FeedfreeInferenceRunner(InferenceRunnerBase): class FeedfreeInferenceRunner(InferenceRunnerBase):
""" A callback that runs a list of :class:`Inferencer` on some """ A callback that runs a list of :class:`Inferencer` on some
:class:`TensorInput`, such as some tensor from a TensorFlow data reading :class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline. pipeline.
""" """
def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None): def __init__(self, input, infs, prefix='', extra_hooks=None):
""" """
Args: Args:
input (TensorInput): the input to use. Must have ``size()``. input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run. infs (list): list of :class:`Inferencer` to run.
input_names (list[str]): same as in :class:`InferenceRunnerBase`. input_names (list[str]): same as in :class:`InferenceRunnerBase`.
prefix(str): a prefix used to build the tower. Must be set prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used. differently if more than one :class:`FeedfreeInferenceRunner` are used.
""" """
assert isinstance(input, TensorInput), input assert isinstance(input, FeedfreeInput), input
super(FeedfreeInferenceRunner, self).__init__( super(FeedfreeInferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks) input, infs, prefix=prefix, extra_hooks=extra_hooks)
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
def _find_input_tensors(self):
# TODO move mapping to InputSource
tensors = self._input_source.get_input_tensors()
placeholders = self.trainer.model.get_reused_placehdrs()
if self.input_names is None:
return tensors
else:
return get_tensors_inputs(placeholders, tensors, self.input_names)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() # all is tensorname out_names = inf.get_output_tensors() # all is tensorname
......
...@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod ...@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
import six import six
from six.moves import range, zip from six.moves import range, zip
from .utils import get_placeholders_by_names from .utils import get_placeholders_by_names, get_tensors_inputs
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
...@@ -25,11 +25,12 @@ from ..utils.concurrency import ShareSessionThread ...@@ -25,11 +25,12 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput', __all__ = ['InputSource',
'FeedInput', 'DataParallelFeedInput', 'FeedInput', 'DataParallelFeedInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'ZMQInput', 'ZMQInput', 'DummyConstantInput', 'TensorInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper'] 'StagingInputWrapper', 'ReorderInputSource']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -73,6 +74,8 @@ class FeedInput(InputSource): ...@@ -73,6 +74,8 @@ class FeedInput(InputSource):
input_names (list[str]): input names this DataFlow maps to input_names (list[str]): input names this DataFlow maps to
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
if input_names is not None:
assert isinstance(input_names, (list, tuple)), input_names
self.ds = ds self.ds = ds
self._input_names = input_names self._input_names = input_names
...@@ -213,7 +216,9 @@ class QueueInput(FeedfreeInput): ...@@ -213,7 +216,9 @@ class QueueInput(FeedfreeInput):
""" """
Args: Args:
ds(DataFlow): the input DataFlow. ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): Defaults to a FIFO queue of size 50. queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.queue = queue self.queue = queue
...@@ -227,7 +232,7 @@ class QueueInput(FeedfreeInput): ...@@ -227,7 +232,7 @@ class QueueInput(FeedfreeInput):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput has to be used with input placeholders!" "QueueInput has to be used with some InputDesc!"
if self.queue is None: if self.queue is None:
self.queue = tf.FIFOQueue( self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs], 50, [x.dtype for x in self.input_placehdrs],
...@@ -259,7 +264,9 @@ class BatchQueueInput(FeedfreeInput): ...@@ -259,7 +264,9 @@ class BatchQueueInput(FeedfreeInput):
Args: Args:
ds(DataFlow): the input DataFlow. ds(DataFlow): the input DataFlow.
batch_size(int): the batch size. batch_size(int): the batch size.
queue (tf.QueueBase): Defaults to a FIFO queue of size 3000. queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 3000.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.queue = queue self.queue = queue
...@@ -273,7 +280,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -273,7 +280,7 @@ class BatchQueueInput(FeedfreeInput):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with input placeholders!" "BatchQueueInput has to be used with some InputDesc!"
# prepare placeholders without the first dimension # prepare placeholders without the first dimension
placehdrs_nobatch = [] placehdrs_nobatch = []
...@@ -366,6 +373,8 @@ class TensorInput(FeedfreeInput): ...@@ -366,6 +373,8 @@ class TensorInput(FeedfreeInput):
get_tensor_fn: a function which returns a list of input tensors get_tensor_fn: a function which returns a list of input tensors
when called. It will be called under a TowerContext. when called. It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined. size(int): size of this input. Use None to leave it undefined.
input_names (list[str]): input names the tensors maps to. Defaults
to be all the inputs of the model.
""" """
self.get_tensor_fn = get_tensor_fn self.get_tensor_fn = get_tensor_fn
if size is not None: if size is not None:
...@@ -491,3 +500,40 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -491,3 +500,40 @@ class StagingInputWrapper(FeedfreeInput):
def get_unstage_op(self): def get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops)) all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs) return tf.group(*all_outputs)
class ReorderInputSource(FeedfreeInput):
    """
    When an InputSource only maps to a subset of the InputDesc of the model,
    wrap it with :class:`ReorderInputSource`.

    It delegates everything to the wrapped input and only remaps the
    produced tensors onto the model's placeholders by name.
    """

    def __init__(self, input, names):
        """
        Args:
            input(TensorInput): a TensorInput, whose tensors will get mapped.
            names(list[str]): list of input names corresponding to the tensors
                produced by ``input``.
        """
        assert isinstance(input, TensorInput), input
        assert isinstance(names, (list, tuple)), names
        self._input = input
        self._names = names

    def size(self):
        # Size is entirely determined by the underlying input.
        return self._input.size()

    def setup(self, model):
        # Remember every placeholder of the model so the subset produced by
        # the wrapped input can later be matched against the full list.
        self._all_placehdrs = model.get_reused_placehdrs()
        self._input.setup(model)

    def setup_training(self, trainer):
        # Same as setup(), but the model comes from the trainer.
        self._all_placehdrs = trainer.model.get_reused_placehdrs()
        self._input.setup_training(trainer)

    def reset_state(self):
        self._input.reset_state()

    def get_input_tensors(self):
        # Map the tensors from the wrapped input to the corresponding
        # model placeholders, using the given names for the pairing.
        tensors = self._input.get_input_tensors()
        return get_tensors_inputs(self._all_placehdrs, tensors, self._names)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment