Commit 28e8f8db authored by Yuxin Wu's avatar Yuxin Wu

prefetchdata shouldn't require .size()

parent 4000f5d5
......@@ -4,6 +4,7 @@
import multiprocessing
from threading import Thread
import itertools
from six.moves import range
from six.moves.queue import Queue
import uuid
......@@ -52,7 +53,10 @@ class PrefetchData(ProxyDataFlow):
of data points will be random.
"""
super(PrefetchData, self).__init__(ds)
self._size = self.size()
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
self.queue = multiprocessing.Queue(self.nr_prefetch)
......@@ -63,7 +67,9 @@ class PrefetchData(ProxyDataFlow):
x.start()
def get_data(self):
for _ in range(self._size):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = self.queue.get()
yield dp
......@@ -96,7 +102,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
of datapoints will be random.
"""
super(PrefetchDataZMQ, self).__init__(ds)
self._size = ds.size()
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc
self.context = zmq.Context()
......@@ -114,7 +123,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
atexit.register(lambda x: x.__del__(), self)
def get_data(self):
for _ in range(self._size):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = loads(self.socket.recv(copy=False))
yield dp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment