Commit 99872a7b authored by Yuxin Wu

Let MultiThreadRunner respect its len() (fix #1231)

parent 49fa0a07
@@ -416,7 +416,8 @@ class MultiThreadRunner(DataFlow):
         Args:
             get_df ( -> DataFlow): a callable which returns a DataFlow.
                 Each thread will call this function to get the DataFlow to use.
-                Therefore do not return the same DataFlow for each call.
+                Therefore do not return the same DataFlow object for each call,
+                unless your dataflow is stateless.
             num_prefetch (int): size of the queue
             num_thread (int): number of threads
             nr_prefetch, nr_thread: deprecated names
@@ -438,6 +439,11 @@ class MultiThreadRunner(DataFlow):
             MultiThreadRunner._Worker(get_df, self.queue)
             for _ in range(num_thread)]
 
+        try:
+            self._size = self.__len__()
+        except NotImplementedError:
+            self._size = -1
+
     def reset_state(self):
         for th in self.threads:
             th.df.reset_state()
@@ -447,7 +453,9 @@ class MultiThreadRunner(DataFlow):
         return self.threads[0].df.__len__()
 
     def __iter__(self):
-        while True:
+        for k in itertools.count():
+            if self._size > 0 and k >= self._size:
+                break
             yield self.queue.get()
 
     def __del__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment