Commit 0641ef09 authored by Yuxin Wu's avatar Yuxin Wu

fix MultiThreadPrefetchData

parent a1da74af
...@@ -63,13 +63,13 @@ class InferenceRunnerBase(Callback): ...@@ -63,13 +63,13 @@ class InferenceRunnerBase(Callback):
Note: Note:
1. InferenceRunner will use `input.size()` to determine 1. InferenceRunner will use `input.size()` to determine
how much iterations to run, so you're responsible to ensure that how much iterations to run, so you're responsible to ensure that
`input.size()` is reasonable. `input.size()` is accurate.
2. Only works with instances of `TowerTrainer`. 2. Only works with instances of `TowerTrainer`.
""" """
def __init__(self, input, infs): def __init__(self, input, infs):
""" """
Args: Args:
input (InputSource): the input to use. Must have ``size()``. input (InputSource): the input to use. Must have an accurate ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run. infs (list[Inferencer]): list of :class:`Inferencer` to run.
""" """
self._input_source = input self._input_source = input
......
...@@ -363,10 +363,11 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -363,10 +363,11 @@ class MultiThreadPrefetchData(DataFlow):
def run(self): def run(self):
self.df.reset_state() self.df.reset_state()
try: try:
for dp in self.df: while True:
if self.stopped(): for dp in self.df:
return if self.stopped():
self.queue_put_stoppable(self.queue, dp) return
self.queue_put_stoppable(self.queue, dp)
except Exception: except Exception:
if self.stopped(): if self.stopped():
pass # skip duplicated error messages pass # skip duplicated error messages
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment