Commit 0641ef09 authored by Yuxin Wu's avatar Yuxin Wu

fix MultiThreadPrefetchData

parent a1da74af
......@@ -63,13 +63,13 @@ class InferenceRunnerBase(Callback):
Note:
1. InferenceRunner will use `input.size()` to determine
how much iterations to run, so you're responsible to ensure that
`input.size()` is reasonable.
`input.size()` is accurate.
2. Only works with instances of `TowerTrainer`.
"""
def __init__(self, input, infs):
"""
Args:
input (InputSource): the input to use. Must have ``size()``.
input (InputSource): the input to use. Must have an accurate ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
"""
self._input_source = input
......
......@@ -363,10 +363,11 @@ class MultiThreadPrefetchData(DataFlow):
def run(self):
self.df.reset_state()
try:
for dp in self.df:
if self.stopped():
return
self.queue_put_stoppable(self.queue, dp)
while True:
for dp in self.df:
if self.stopped():
return
self.queue_put_stoppable(self.queue, dp)
except Exception:
if self.stopped():
pass # skip duplicated error messages
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment