Commit 0fa90c41 authored by Yuxin Wu's avatar Yuxin Wu

ParallelMap: use infinite iterator when strict=False

parent 0a5739fa
......@@ -95,21 +95,15 @@ class _MultiProcessZMQDataFlow(DataFlow):
def reset_state(self):
"""
All forked dataflows are reset **once and only once** in spawned processes.
Nothing more can be done when calling this method.
All forked dataflows should only be reset **once and only once** in spawned processes.
Subclasses should call this method with super.
"""
if self._reset_done:
return
assert not self._reset_done, "reset_state() was called twice! This violates the API of DataFlow!"
self._reset_done = True
# __del__ not guaranteed to get called at exit
atexit.register(del_weakref, weakref.ref(self))
self._reset_once() # build processes
def _reset_once(self):
    """Hook for subclasses: one-time setup (e.g. building worker processes).

    Called exactly once from ``reset_state()``; the default implementation
    does nothing.
    """
    pass
def _start_processes(self):
    # Start all forked worker processes.
    # NOTE(review): the exact signal-masking behavior comes from
    # utils.concurrency.start_proc_mask_signal -- confirm semantics there.
    start_proc_mask_signal(self._procs)
......@@ -315,7 +309,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
break
yield self._recv()
def _reset_once(self):
def reset_state(self):
super(PrefetchDataZMQ, self).reset_state()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(self._hwm)
......@@ -400,7 +395,7 @@ class MultiThreadPrefetchData(DataFlow):
th.start()
def __len__(self):
return self.threads[0].__len__()
return self.threads[0].df.__len__()
def __iter__(self):
while True:
......@@ -463,3 +458,16 @@ plasma = None
# from ..utils.develop import create_dummy_class
# PlasmaPutData = create_dummy_class('PlasmaPutData', 'pyarrow') # noqa
# PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow') # noqa
# Smoke test: stream an infinite counter through PrefetchDataZMQ with two
# worker processes and print datapoints.  Must be run as a module so the
# relative imports resolve.
if __name__ == '__main__':
    import time
    from .raw import DataFromGenerator
    from .common import FixedSizeData
    # itertools.count() yields 0, 1, 2, ... forever; FixedSizeData caps one
    # epoch at 100 datapoints.
    x = DataFromGenerator(itertools.count())
    x = FixedSizeData(x, 100)
    # NOTE(review): with 2 workers the output order is nondeterministic --
    # presumably acceptable for a manual smoke test.
    x = PrefetchDataZMQ(x, 2)
    x.reset_state()  # required once before iterating any DataFlow
    for idx, dp in enumerate(x):
        print(dp)
        time.sleep(0.1)  # throttle printing
......@@ -9,6 +9,7 @@ from six.moves import queue
import zmq
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from .common import RepeatedData
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils import logger
from ..utils.serialize import loads, dumps
......@@ -23,11 +24,18 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData',
class _ParallelMapData(ProxyDataFlow):
def __init__(self, ds, buffer_size):
def __init__(self, ds, buffer_size, strict=False):
if not strict:
ds = RepeatedData(ds, -1)
super(_ParallelMapData, self).__init__(ds)
assert buffer_size > 0, buffer_size
self._buffer_size = buffer_size
self._buffer_occupancy = 0 # actual #elements in buffer
self._buffer_occupancy = 0 # actual #elements in buffer, only useful in strict mode
self._strict = strict
def reset_state(self):
    """Reset the wrapped dataflow and cache a fresh iterator over it.

    The cached ``self._iter`` is what the buffer-filling logic pulls
    datapoints from.
    """
    super(_ParallelMapData, self).reset_state()
    self._iter = self.ds.__iter__()
def _recv(self):
pass
......@@ -50,7 +58,8 @@ class _ParallelMapData(ProxyDataFlow):
self._send(dp)
except StopIteration:
logger.error(
"[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format(
type(self).__name__))
raise
self._buffer_occupancy += cnt
......@@ -61,13 +70,6 @@ class _ParallelMapData(ProxyDataFlow):
if ret is not None:
yield ret
self._iter = self.ds.__iter__() # refresh
for _ in range(self._buffer_size):
self._send(next(self._iter))
ret = self._recv()
if ret is not None:
yield ret
def get_data_strict(self):
self._fill_buffer()
for dp in self._iter:
......@@ -83,6 +85,14 @@ class _ParallelMapData(ProxyDataFlow):
self._fill_buffer()
yield dp
def __iter__(self):
    """Yield mapped datapoints, dispatching on the configured strictness."""
    source = self.get_data_strict() if self._strict else self.get_data_non_strict()
    for datapoint in source:
        yield datapoint
class MultiThreadMapData(_ParallelMapData):
"""
......@@ -141,7 +151,7 @@ class MultiThreadMapData(_ParallelMapData):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super(MultiThreadMapData, self).__init__(ds, buffer_size)
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
self._strict = strict
self.nr_thread = nr_thread
......@@ -165,7 +175,6 @@ class MultiThreadMapData(_ParallelMapData):
for t in self._threads:
t.start()
self._iter = self.ds.__iter__()
self._guard = DataFlowReentrantGuard()
# Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
......@@ -179,11 +188,7 @@ class MultiThreadMapData(_ParallelMapData):
def __iter__(self):
with self._guard:
if self._strict:
for dp in self.get_data_strict():
yield dp
else:
for dp in self.get_data_non_strict():
for dp in super(MultiThreadMapData, self).__iter__():
yield dp
def __del__(self):
......@@ -245,7 +250,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
_ParallelMapData.__init__(self, ds, buffer_size)
_ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self)
self.nr_proc = nr_proc
self.map_func = map_func
......@@ -253,7 +258,10 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self._procs = []
self._guard = DataFlowReentrantGuard()
def _reset_once(self):
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
_ParallelMapData.reset_state(self)
self.context = zmq.Context()
self.socket = self.context.socket(zmq.DEALER)
self.socket.set_hwm(self._buffer_size * 2)
......@@ -266,15 +274,9 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
self.ds.reset_state()
self._iter = self.ds.__iter__()
self._start_processes()
self._fill_buffer() # pre-fill the bufer
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
def _send(self, dp):
msg = [b"", dumps(dp)]
self.socket.send_multipart(msg, copy=False)
......@@ -286,11 +288,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
def __iter__(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
if self._strict:
for dp in self.get_data_strict():
yield dp
else:
for dp in self.get_data_non_strict():
for dp in super(MultiProcessMapDataZMQ, self).__iter__():
yield dp
......@@ -388,6 +386,8 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
if __name__ == '__main__':
import time
class Zero(DataFlow):
def __init__(self, size):
    # Number of datapoints per epoch, as reported by __len__.
    self._size = size
......@@ -399,8 +399,13 @@ if __name__ == '__main__':
def __len__(self):
    # Epoch size, as configured in the constructor.
    return self._size
ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
def f(x):
    """Identity mapping that stalls for 1s when the first element is below 10.

    Used in the smoke test to simulate a map function with occasional
    slow datapoints.
    """
    needs_delay = x[0] < 10
    if needs_delay:
        time.sleep(1)
    return x
ds = Zero(100)
ds = MultiThreadMapData(ds, 50, f, buffer_size=50, strict=False)
ds.reset_state()
for k in ds:
print("Bang!", k)
......
......@@ -187,8 +187,6 @@ class EnqueueThread(ShareSessionThread):
class QueueInput(FeedfreeInput):
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors.
Calling :meth:`reset_state()` will clear the queue and reset the dataflow.
"""
def __init__(self, ds, queue=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment