Commit 99872a7b authored by Yuxin Wu

Let MultiThreadRunner respect its len() (fix #1231)

parent 49fa0a07
......@@ -416,7 +416,8 @@ class MultiThreadRunner(DataFlow):
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow for each call.
Therefore do not return the same DataFlow object for each call,
unless your dataflow is stateless.
num_prefetch (int): size of the queue
num_thread (int): number of threads
nr_prefetch, nr_thread: deprecated names
......@@ -438,6 +439,11 @@ class MultiThreadRunner(DataFlow):
MultiThreadRunner._Worker(get_df, self.queue)
for _ in range(num_thread)]
try:
self._size = self.__len__()
except NotImplementedError:
self._size = -1
def reset_state(self):
for th in self.threads:
th.df.reset_state()
......@@ -447,7 +453,9 @@ class MultiThreadRunner(DataFlow):
return self.threads[0].df.__len__()
def __iter__(self):
while True:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self.queue.get()
def __del__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment