Commit 9626ebd8 authored by Yuxin Wu's avatar Yuxin Wu

remove input_names from InferenceRunnerBase

parent 76fa8e38
......@@ -56,13 +56,11 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Callback):
""" Base methods for inference runner"""
def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None):
def __init__(self, input, infs, prefix='', extra_hooks=None):
"""
Args:
input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
......@@ -74,9 +72,6 @@ class InferenceRunnerBase(Callback):
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), v
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
try:
self._size = input.size()
......@@ -95,7 +90,7 @@ class InferenceRunnerBase(Callback):
def fn(_):
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
assert isinstance(in_tensors, (list, tuple)), in_tensors
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
......@@ -140,12 +135,13 @@ class InferenceRunner(InferenceRunnerBase):
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
"""
assert isinstance(ds, DataFlow), ds
input = FeedInput(ds, input_names)
super(InferenceRunner, self).__init__(
input, infs, input_names, prefix='', extra_hooks=extra_hooks)
input, infs, prefix='', extra_hooks=extra_hooks)
def _find_input_tensors(self):
return self._input_source.get_input_tensors()
......@@ -173,7 +169,10 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
"""
assert isinstance(input, TensorInput), input
super(FeedfreeInferenceRunner, self).__init__(
input, infs, input_names, prefix=prefix, extra_hooks=extra_hooks)
input, infs, prefix=prefix, extra_hooks=extra_hooks)
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
def _find_input_tensors(self):
# TODO move mapping to InputSource
......@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
for k in range(len(gpus))]
input = DataParallelFeedInput(
ds, self._tower_names, input_names=input_names)
super(DataParallelInferenceRunner, self).__init__(
input, infs, input_names)
super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus
def _setup_graph(self):
......
......@@ -25,7 +25,8 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput',
__all__ = ['InputSource', 'FeedfreeInput',
'FeedInput', 'DataParallelFeedInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
......@@ -54,12 +55,13 @@ class InputSource(object):
def reset_state(self):
pass
@abstractmethod
def next_feed(self):
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return {}
pass
class FeedInput(InputSource):
......@@ -128,6 +130,7 @@ class DataParallelFeedInput(FeedInput):
# input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names))
print(self._feed_placehdrs_per_tower[-1])
self.reset_state()
def get_input_tensors(self):
......@@ -158,10 +161,13 @@ class FeedfreeInput(InputSource):
# TODO cannot reset
pass
def next_feed(self):
return {}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, input_placehdrs):
def __init__(self, queue, ds, placehdrs):
super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread'
self.daemon = True
......@@ -169,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
self.dataflow = ds
self.queue = queue
self.placehdrs = input_placehdrs
self.placehdrs = placehdrs
self.op = self.queue.enqueue(self.placehdrs)
self.close_op = self.queue.close(cancel_pending_enqueues=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment