Commit 9626ebd8 authored by Yuxin Wu's avatar Yuxin Wu

remove input_names from InferenceRunnerBase

parent 76fa8e38
...@@ -56,13 +56,11 @@ def summary_inferencer(trainer, infs): ...@@ -56,13 +56,11 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Callback): class InferenceRunnerBase(Callback):
""" Base methods for inference runner""" """ Base methods for inference runner"""
def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None): def __init__(self, input, infs, prefix='', extra_hooks=None):
""" """
Args: Args:
input (InputSource): the input to use. Must have ``size()``. input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run. infs (list[Inferencer]): list of :class:`Inferencer` to run.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
prefix(str): an prefix used to build the tower. Must be set prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used. differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation. extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
...@@ -74,9 +72,6 @@ class InferenceRunnerBase(Callback): ...@@ -74,9 +72,6 @@ class InferenceRunnerBase(Callback):
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
try: try:
self._size = input.size() self._size = input.size()
...@@ -95,7 +90,7 @@ class InferenceRunnerBase(Callback): ...@@ -95,7 +90,7 @@ class InferenceRunnerBase(Callback):
def fn(_): def fn(_):
in_tensors = self._find_input_tensors() in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors assert isinstance(in_tensors, (list, tuple)), in_tensors
self.trainer.model.build_graph(in_tensors) self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
...@@ -140,12 +135,13 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -140,12 +135,13 @@ class InferenceRunner(InferenceRunnerBase):
Args: Args:
ds (DataFlow): the DataFlow to run inferencer on. ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances. infs (list): a list of `Inferencer` instances.
input_names (list[str]): same as in :class:`InferenceRunnerBase`. input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
input = FeedInput(ds, input_names) input = FeedInput(ds, input_names)
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, input_names, prefix='', extra_hooks=extra_hooks) input, infs, prefix='', extra_hooks=extra_hooks)
def _find_input_tensors(self): def _find_input_tensors(self):
return self._input_source.get_input_tensors() return self._input_source.get_input_tensors()
...@@ -173,7 +169,10 @@ class FeedfreeInferenceRunner(InferenceRunnerBase): ...@@ -173,7 +169,10 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
""" """
assert isinstance(input, TensorInput), input assert isinstance(input, TensorInput), input
super(FeedfreeInferenceRunner, self).__init__( super(FeedfreeInferenceRunner, self).__init__(
input, infs, input_names, prefix=prefix, extra_hooks=extra_hooks) input, infs, prefix=prefix, extra_hooks=extra_hooks)
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
def _find_input_tensors(self): def _find_input_tensors(self):
# TODO move mapping to InputSource # TODO move mapping to InputSource
...@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
for k in range(len(gpus))] for k in range(len(gpus))]
input = DataParallelFeedInput( input = DataParallelFeedInput(
ds, self._tower_names, input_names=input_names) ds, self._tower_names, input_names=input_names)
super(DataParallelInferenceRunner, self).__init__( super(DataParallelInferenceRunner, self).__init__(input, infs)
input, infs, input_names)
self._gpus = gpus self._gpus = gpus
def _setup_graph(self): def _setup_graph(self):
......
...@@ -25,7 +25,8 @@ from ..utils.concurrency import ShareSessionThread ...@@ -25,7 +25,8 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput', __all__ = ['InputSource', 'FeedfreeInput',
'FeedInput', 'DataParallelFeedInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'ZMQInput', 'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper'] 'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
...@@ -54,12 +55,13 @@ class InputSource(object): ...@@ -54,12 +55,13 @@ class InputSource(object):
def reset_state(self): def reset_state(self):
pass pass
@abstractmethod
def next_feed(self): def next_feed(self):
""" """
Returns: Returns:
a feed_dict of {Tensor: data}, to be used to run the steps a feed_dict of {Tensor: data}, to be used to run the steps
""" """
return {} pass
class FeedInput(InputSource): class FeedInput(InputSource):
...@@ -128,6 +130,7 @@ class DataParallelFeedInput(FeedInput): ...@@ -128,6 +130,7 @@ class DataParallelFeedInput(FeedInput):
# input_names to be used for this specific tower # input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append( self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names)) get_placeholders_by_names(phdrs, input_names))
print(self._feed_placehdrs_per_tower[-1])
self.reset_state() self.reset_state()
def get_input_tensors(self): def get_input_tensors(self):
...@@ -158,10 +161,13 @@ class FeedfreeInput(InputSource): ...@@ -158,10 +161,13 @@ class FeedfreeInput(InputSource):
# TODO cannot reset # TODO cannot reset
pass pass
def next_feed(self):
return {}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155 # TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, input_placehdrs): def __init__(self, queue, ds, placehdrs):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread' self.name = 'EnqueueThread'
self.daemon = True self.daemon = True
...@@ -169,7 +175,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -169,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
self.dataflow = ds self.dataflow = ds
self.queue = queue self.queue = queue
self.placehdrs = input_placehdrs self.placehdrs = placehdrs
self.op = self.queue.enqueue(self.placehdrs) self.op = self.queue.enqueue(self.placehdrs)
self.close_op = self.queue.close(cancel_pending_enqueues=True) self.close_op = self.queue.close(cancel_pending_enqueues=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment