Commit 28e8f8db authored by Yuxin Wu

prefetchdata shouldn't require .size()

parent 4000f5d5
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import multiprocessing import multiprocessing
from threading import Thread from threading import Thread
import itertools
from six.moves import range from six.moves import range
from six.moves.queue import Queue from six.moves.queue import Queue
import uuid import uuid
...@@ -52,7 +53,10 @@ class PrefetchData(ProxyDataFlow): ...@@ -52,7 +53,10 @@ class PrefetchData(ProxyDataFlow):
of data points will be random. of data points will be random.
""" """
super(PrefetchData, self).__init__(ds) super(PrefetchData, self).__init__(ds)
self._size = self.size() try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch self.nr_prefetch = nr_prefetch
self.queue = multiprocessing.Queue(self.nr_prefetch) self.queue = multiprocessing.Queue(self.nr_prefetch)
...@@ -63,7 +67,9 @@ class PrefetchData(ProxyDataFlow): ...@@ -63,7 +67,9 @@ class PrefetchData(ProxyDataFlow):
x.start() x.start()
def get_data(self): def get_data(self):
for _ in range(self._size): for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = self.queue.get() dp = self.queue.get()
yield dp yield dp
...@@ -96,7 +102,10 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -96,7 +102,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
of datapoints will be random. of datapoints will be random.
""" """
super(PrefetchDataZMQ, self).__init__(ds) super(PrefetchDataZMQ, self).__init__(ds)
try:
self._size = ds.size() self._size = ds.size()
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.context = zmq.Context() self.context = zmq.Context()
...@@ -114,7 +123,9 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -114,7 +123,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
atexit.register(lambda x: x.__del__(), self) atexit.register(lambda x: x.__del__(), self)
def get_data(self): def get_data(self):
for _ in range(self._size): for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = loads(self.socket.recv(copy=False)) dp = loads(self.socket.recv(copy=False))
yield dp yield dp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment