Commit 04b52d20 authored by Yuxin Wu's avatar Yuxin Wu

fix InferenceRunner for zero input size (#565)

parent db36b929
......@@ -120,6 +120,7 @@ class ScalarStats(Inferencer):
self.stats.append(output)
def _after_inference(self):
if len(self.stats):
self.stats = np.mean(self.stats, axis=0)
assert len(self.stats) == len(self.names)
......
......@@ -3,6 +3,7 @@
# File: inference_runner.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import sys
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
......@@ -158,9 +159,12 @@ class InferenceRunner(InferenceRunnerBase):
# iterate over the data, and run the hooked session
self._input_source.reset_state()
with _inference_context():
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
with _inference_context(), \
tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar:
num_itr = self._size if self._size > 0 else sys.maxsize
for _ in range(num_itr):
self._hooked_sess.run(fetches=[])
pbar.update()
for inf in self.infs:
inf.trigger_epoch()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment