Commit 3951aaf7 authored by Yuxin Wu's avatar Yuxin Wu

merge Feed/Feedfree InferenceRunner

parent 05d1cbe7
......@@ -142,9 +142,9 @@ def get_config():
'learning_rate',
lambda e, x: x * 0.80 if e > 6 else x),
RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]),
InferenceRunner(val_data, [ScalarStats(['cost'])]),
RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner(
InferenceRunner(
test_data,
[ScalarStats(['cost'], prefix='test')], prefix='test'),
RunOp(lambda: M.reset_lstm_state()),
......
......@@ -13,11 +13,14 @@ import six
from six.moves import range
from ..utils import logger, get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput)
FeedInput, DataParallelFeedInput, FeedfreeInput, TensorInput)
from ..predict import PredictorTowerBuilder
from .base import Callback
......@@ -118,21 +121,23 @@ class InferenceRunnerBase(Callback):
class InferenceRunner(InferenceRunnerBase):
"""
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, extra_hooks=None):
def __init__(self, input, infs, prefix='', extra_hooks=None):
"""
Args:
input (FeedInput or DataFlow): the FeedInput, or the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input (InputSource or DataFlow): The :class:`InputSource` to run
inference on. If given a DataFlow, will use :class:`FeedInput`.
infs (list): a list of :class:`Inferencer` instances.
"""
if isinstance(input, DataFlow):
input = FeedInput(input)
assert isinstance(input, FeedInput), input
assert isinstance(input, InputSource), input
if isinstance(input, FeedfreeInput): # TODO support other input
assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks)
input, infs, prefix=prefix, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_output_tensors()
......@@ -140,32 +145,9 @@ class InferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches)
class FeedfreeInferenceRunner(InferenceRunnerBase):
""" A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
"""
def __init__(self, input, infs, prefix='', extra_hooks=None):
"""
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
assert isinstance(input, FeedfreeInput), input
super(FeedfreeInferenceRunner, self).__init__(
input, infs, prefix=prefix, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_output_tensors() # all is tensorname
placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
ret = []
for name in out_names:
assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
ret.append(self._tower_handle.get_tensors([name])[0])
return InferencerToHook(inf, ret)
@deprecated("Just use InferenceRunner since it now accepts TensorInput!")
def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs)
# TODO some scripts to test
......
......@@ -14,16 +14,22 @@ __all__ = ['PredictorFactory']
class PredictorTowerHandle(object):
def __init__(self, tower_name, input_tensors):
def __init__(self, tower_name, input_desc_names, input_tensors=None):
self._tower_name = tower_name
self._input_tensors = input_tensors
self._input_names = [get_op_tensor_name(k.name)[1] for k in input_tensors]
self._input_desc_names = [get_op_tensor_name(k)[1] for k in input_desc_names]
if input_tensors is not None:
self._input_names = [get_op_tensor_name(k.name)[1] for k in input_tensors]
else:
self._input_names = self._input_desc_names
def get_tensors(self, names):
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[1]
if name in self._input_names:
return name
elif name in self._input_desc_names:
idx = self._input_desc_names.index(name)
return self._input_names[idx]
else:
# if the name is not a placeholder, use it's name in each tower
return self._tower_name + '/' + name
......@@ -62,7 +68,10 @@ class PredictorFactory(object):
input = input.get_input_tensors()
assert isinstance(input, (list, tuple)), input
self._model.build_graph(input)
self._names_built[tower_name] = PredictorTowerHandle(tower_name, input)
desc_names = [k.name for k in self._model.get_inputs_desc()]
self._names_built[tower_name] = PredictorTowerHandle(
tower_name, desc_names, input)
return self._names_built[tower_name]
def has_built(self, tower_name):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment