Commit 7a5b861a authored by Yuxin Wu's avatar Yuxin Wu

Don't require size() in InferenceRunner (#397)

parent be50085f
......@@ -45,12 +45,12 @@ def _inference_context():
msg = "You might need to check your input implementation."
try:
yield
except (StopIteration,
tf.errors.CancelledError,
tf.errors.OutOfRangeError):
except (StopIteration, tf.errors.CancelledError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
raise
except tf.errors.OutOfRangeError: # tf.data reaches an end
pass
class InferenceRunnerBase(Callback):
......@@ -76,9 +76,10 @@ class InferenceRunnerBase(Callback):
try:
self._size = input.size()
logger.info("InferenceRunner will eval {} iterations".format(input.size()))
except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!")
logger.info("InferenceRunner will eval on an InputSource of size {}".format(self._size))
self._size = 0
logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
self._hooks = []
......@@ -181,6 +182,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input = QueueInput(input)
assert isinstance(input, QueueInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs)
assert self._size > 0, "Input for DataParallelInferenceRunner must have a size!"
self._gpus = gpus
def _setup_graph(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment