Commit 5b310290 authored by Yuxin Wu's avatar Yuxin Wu

infinite option in FeedInput

parent 3745b4d5
...@@ -115,8 +115,14 @@ class InferenceRunnerBase(Callback): ...@@ -115,8 +115,14 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()): msg = "You might need to check your input implementation."
self._hooked_sess.run(fetches=[]) try:
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
self._hooked_sess.run(fetches=[])
except StopIteration:
raise RuntimeError(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
...@@ -133,7 +139,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -133,7 +139,7 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances. infs (list): a list of :class:`Inferencer` instances.
""" """
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
input = FeedInput(input) input = FeedInput(input, infinite=False)
assert isinstance(input, InputSource), input assert isinstance(input, InputSource), input
if isinstance(input, FeedfreeInput): # TODO support other input if isinstance(input, FeedfreeInput): # TODO support other input
assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!" assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
......
...@@ -65,22 +65,26 @@ class FeedInput(InputSource): ...@@ -65,22 +65,26 @@ class FeedInput(InputSource):
self._ds.reset_state() self._ds.reset_state()
self._itr = self._ds.get_data() self._itr = self._ds.get_data()
def __init__(self, ds): def __init__(self, ds, infinite=True):
""" """
Args: Args:
ds (DataFlow): the input DataFlow. ds (DataFlow): the input DataFlow.
infinite (bool): When set to False, will raise StopIteration when
ds is exhausted.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
# TODO avoid infinite repeat, to allow accurate size handling if infinite:
self._repeat_ds = RepeatedData(self.ds, -1) self._iter_ds = RepeatedData(self.ds, -1)
else:
self._iter_ds = self.ds
def _size(self): def _size(self):
return self.ds.size() return self.ds.size()
def _setup(self, inputs): def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder(prefix='') for v in inputs] self._all_placehdrs = [v.build_placeholder(prefix='') for v in inputs]
self._cb = self._FeedCallback(self._repeat_ds, self._all_placehdrs) self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)
self.reset_state() self.reset_state()
def _get_input_tensors(self): def _get_input_tensors(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment