Commit 04b52d20 authored by Yuxin Wu's avatar Yuxin Wu

fix InferenceRunner for zero input size (#565)

parent db36b929
...@@ -120,8 +120,9 @@ class ScalarStats(Inferencer): ...@@ -120,8 +120,9 @@ class ScalarStats(Inferencer):
self.stats.append(output) self.stats.append(output)
def _after_inference(self): def _after_inference(self):
self.stats = np.mean(self.stats, axis=0) if len(self.stats):
assert len(self.stats) == len(self.names) self.stats = np.mean(self.stats, axis=0)
assert len(self.stats) == len(self.names)
ret = {} ret = {}
for stat, name in zip(self.stats, self.names): for stat, name in zip(self.stats, self.names):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: inference_runner.py # File: inference_runner.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.python.training.monitored_session \ from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession import _HookedSession as HookedSession
...@@ -158,9 +159,12 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -158,9 +159,12 @@ class InferenceRunner(InferenceRunnerBase):
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
with _inference_context(): with _inference_context(), \
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()): tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar:
num_itr = self._size if self._size > 0 else sys.maxsize
for _ in range(num_itr):
self._hooked_sess.run(fetches=[]) self._hooked_sess.run(fetches=[])
pbar.update()
for inf in self.infs: for inf in self.infs:
inf.trigger_epoch() inf.trigger_epoch()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment