Commit 3951aaf7 authored by Yuxin Wu's avatar Yuxin Wu

merge Feed/Feedfree InferenceRunner

parent 05d1cbe7
...@@ -142,9 +142,9 @@ def get_config(): ...@@ -142,9 +142,9 @@ def get_config():
'learning_rate', 'learning_rate',
lambda e, x: x * 0.80 if e > 6 else x), lambda e, x: x * 0.80 if e > 6 else x),
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]), InferenceRunner(val_data, [ScalarStats(['cost'])]),
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner( InferenceRunner(
test_data, test_data,
[ScalarStats(['cost'], prefix='test')], prefix='test'), [ScalarStats(['cost'], prefix='test')], prefix='test'),
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
......
...@@ -13,11 +13,14 @@ import six ...@@ -13,11 +13,14 @@ import six
from six.moves import range from six.moves import range
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import ( from ..graph_builder.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput) FeedInput, DataParallelFeedInput, FeedfreeInput, TensorInput)
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Callback from .base import Callback
...@@ -118,21 +121,23 @@ class InferenceRunnerBase(Callback): ...@@ -118,21 +121,23 @@ class InferenceRunnerBase(Callback):
class InferenceRunner(InferenceRunnerBase): class InferenceRunner(InferenceRunnerBase):
""" """
A callback that runs a list of :class:`Inferencer` on some A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
:class:`DataFlow`.
""" """
def __init__(self, input, infs, extra_hooks=None): def __init__(self, input, infs, prefix='', extra_hooks=None):
""" """
Args: Args:
input (FeedInput or DataFlow): the FeedInput, or the DataFlow to run inferencer on. input (InputSource or DataFlow): The :class:`InputSource` to run
infs (list): a list of `Inferencer` instances. inference on. If given a DataFlow, will use :class:`FeedInput`.
infs (list): a list of :class:`Inferencer` instances.
""" """
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
input = FeedInput(input) input = FeedInput(input)
assert isinstance(input, FeedInput), input assert isinstance(input, InputSource), input
if isinstance(input, FeedfreeInput): # TODO support other input
assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks) input, infs, prefix=prefix, extra_hooks=extra_hooks)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_output_tensors()
...@@ -140,32 +145,9 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -140,32 +145,9 @@ class InferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches) return InferencerToHook(inf, fetches)
class FeedfreeInferenceRunner(InferenceRunnerBase): @deprecated("Just use InferenceRunner since it now accepts TensorInput!")
""" A callback that runs a list of :class:`Inferencer` on some def FeedfreeInferenceRunner(*args, **kwargs):
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading return InferenceRunner(*args, **kwargs)
pipeline.
"""
def __init__(self, input, infs, prefix='', extra_hooks=None):
"""
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
assert isinstance(input, FeedfreeInput), input
super(FeedfreeInferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_output_tensors() # all is tensorname
placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
ret = []
for name in out_names:
assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
ret.append(self._tower_handle.get_tensors([name])[0])
return InferencerToHook(inf, ret)
# TODO some scripts to test # TODO some scripts to test
......
...@@ -14,16 +14,22 @@ __all__ = ['PredictorFactory'] ...@@ -14,16 +14,22 @@ __all__ = ['PredictorFactory']
class PredictorTowerHandle(object): class PredictorTowerHandle(object):
def __init__(self, tower_name, input_tensors): def __init__(self, tower_name, input_desc_names, input_tensors=None):
self._tower_name = tower_name self._tower_name = tower_name
self._input_tensors = input_tensors self._input_desc_names = [get_op_tensor_name(k)[1] for k in input_desc_names]
self._input_names = [get_op_tensor_name(k.name)[1] for k in input_tensors] if input_tensors is not None:
self._input_names = [get_op_tensor_name(k.name)[1] for k in input_tensors]
else:
self._input_names = self._input_desc_names
def get_tensors(self, names): def get_tensors(self, names):
def maybe_inside_tower(name): def maybe_inside_tower(name):
name = get_op_tensor_name(name)[1] name = get_op_tensor_name(name)[1]
if name in self._input_names: if name in self._input_names:
return name return name
elif name in self._input_desc_names:
idx = self._input_desc_names.index(name)
return self._input_names[idx]
else: else:
# if the name is not a placeholder, use it's name in each tower # if the name is not a placeholder, use it's name in each tower
return self._tower_name + '/' + name return self._tower_name + '/' + name
...@@ -62,7 +68,10 @@ class PredictorFactory(object): ...@@ -62,7 +68,10 @@ class PredictorFactory(object):
input = input.get_input_tensors() input = input.get_input_tensors()
assert isinstance(input, (list, tuple)), input assert isinstance(input, (list, tuple)), input
self._model.build_graph(input) self._model.build_graph(input)
self._names_built[tower_name] = PredictorTowerHandle(tower_name, input)
desc_names = [k.name for k in self._model.get_inputs_desc()]
self._names_built[tower_name] = PredictorTowerHandle(
tower_name, desc_names, input)
return self._names_built[tower_name] return self._names_built[tower_name]
def has_built(self, tower_name): def has_built(self, tower_name):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment