Commit 0fa90c41 authored by Yuxin Wu

ParallelMap: use infinite iterator when strict=False

parent 0a5739fa
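
For context, this change makes the non-strict parallel mappers consume an infinite stream: when strict=False, _ParallelMapData wraps its input in RepeatedData(ds, -1), so the worker buffer never hits StopIteration at an epoch boundary. Below is a minimal usage sketch of the two modes; it is not part of this commit, and the example dataflow and parameters are illustrative only.

# Illustrative sketch only -- assumes tensorpack's dataflow API at this revision.
from tensorpack.dataflow import DataFromList, MultiThreadMapData

ds = DataFromList([[i] for i in range(100)])

# strict=False (default): the input is wrapped in RepeatedData(ds, -1) internally,
# so workers draw from an endless iterator and datapoints may cross epoch boundaries.
ds = MultiThreadMapData(ds, nr_thread=4, map_func=lambda dp: [dp[0] + 1],
                        buffer_size=20, strict=False)

# strict=True: one output per input per epoch; buffer_size must not exceed len(ds).
# ds = MultiThreadMapData(ds, 4, lambda dp: [dp[0] + 1], buffer_size=20, strict=True)

ds.reset_state()
for i, dp in enumerate(ds):
    print(dp)
    if i == 10:     # in non-strict mode the stream is infinite, so stop manually
        break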
@@ -95,21 +95,15 @@ class _MultiProcessZMQDataFlow(DataFlow):
     def reset_state(self):
         """
-        All forked dataflows are reset **once and only once** in spawned processes.
-        Nothing more can be done when calling this method.
+        All forked dataflows should only be reset **once and only once** in spawned processes.
+        Subclasses should call this method with super.
         """
-        if self._reset_done:
-            return
+        assert not self._reset_done, "reset_state() was called twice! This violates the API of DataFlow!"
         self._reset_done = True

         # __del__ not guaranteed to get called at exit
         atexit.register(del_weakref, weakref.ref(self))

-        self._reset_once()  # build processes
-
-    def _reset_once(self):
-        pass
-
     def _start_processes(self):
         start_proc_mask_signal(self._procs)
@@ -315,7 +309,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
                     break
                 yield self._recv()

-    def _reset_once(self):
+    def reset_state(self):
+        super(PrefetchDataZMQ, self).reset_state()
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
         self.socket.set_hwm(self._hwm)
@@ -400,7 +395,7 @@ class MultiThreadPrefetchData(DataFlow):
             th.start()

     def __len__(self):
-        return self.threads[0].__len__()
+        return self.threads[0].df.__len__()

     def __iter__(self):
         while True:
@@ -463,3 +458,16 @@ plasma = None
 # from ..utils.develop import create_dummy_class
 # PlasmaPutData = create_dummy_class('PlasmaPutData', 'pyarrow')  # noqa
 # PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow')  # noqa
+
+
+if __name__ == '__main__':
+    import time
+    from .raw import DataFromGenerator
+    from .common import FixedSizeData
+    x = DataFromGenerator(itertools.count())
+    x = FixedSizeData(x, 100)
+    x = PrefetchDataZMQ(x, 2)
+    x.reset_state()
+    for idx, dp in enumerate(x):
+        print(dp)
+        time.sleep(0.1)
@@ -9,6 +9,7 @@ from six.moves import queue
 import zmq

 from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
+from .common import RepeatedData
 from ..utils.concurrency import StoppableThread, enable_death_signal
 from ..utils import logger
 from ..utils.serialize import loads, dumps
@@ -23,11 +24,18 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData',

 class _ParallelMapData(ProxyDataFlow):
-    def __init__(self, ds, buffer_size):
+    def __init__(self, ds, buffer_size, strict=False):
+        if not strict:
+            ds = RepeatedData(ds, -1)
         super(_ParallelMapData, self).__init__(ds)

         assert buffer_size > 0, buffer_size
         self._buffer_size = buffer_size
-        self._buffer_occupancy = 0  # actual #elements in buffer
+        self._buffer_occupancy = 0  # actual #elements in buffer, only useful in strict mode
+        self._strict = strict
+
+    def reset_state(self):
+        super(_ParallelMapData, self).reset_state()
+        self._iter = self.ds.__iter__()

     def _recv(self):
         pass
@@ -50,7 +58,8 @@ class _ParallelMapData(ProxyDataFlow):
                 self._send(dp)
             except StopIteration:
                 logger.error(
-                    "[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
+                    "[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format(
+                        type(self).__name__))
                 raise
         self._buffer_occupancy += cnt
@@ -61,13 +70,6 @@ class _ParallelMapData(ProxyDataFlow):
             if ret is not None:
                 yield ret

-        self._iter = self.ds.__iter__()   # refresh
-        for _ in range(self._buffer_size):
-            self._send(next(self._iter))
-            ret = self._recv()
-            if ret is not None:
-                yield ret
-
     def get_data_strict(self):
         self._fill_buffer()
         for dp in self._iter:
@@ -83,6 +85,14 @@ class _ParallelMapData(ProxyDataFlow):
                 self._fill_buffer()
             yield dp

+    def __iter__(self):
+        if self._strict:
+            for dp in self.get_data_strict():
+                yield dp
+        else:
+            for dp in self.get_data_non_strict():
+                yield dp
+

 class MultiThreadMapData(_ParallelMapData):
     """
@@ -141,7 +151,7 @@ class MultiThreadMapData(_ParallelMapData):
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
         """
-        super(MultiThreadMapData, self).__init__(ds, buffer_size)
+        super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)

         self._strict = strict
         self.nr_thread = nr_thread
@@ -165,7 +175,6 @@ class MultiThreadMapData(_ParallelMapData):
         for t in self._threads:
             t.start()

-        self._iter = self.ds.__iter__()
         self._guard = DataFlowReentrantGuard()

         # Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
@@ -179,11 +188,7 @@ class MultiThreadMapData(_ParallelMapData):

     def __iter__(self):
         with self._guard:
-            if self._strict:
-                for dp in self.get_data_strict():
-                    yield dp
-            else:
-                for dp in self.get_data_non_strict():
-                    yield dp
+            for dp in super(MultiThreadMapData, self).__iter__():
+                yield dp

     def __del__(self):
@@ -245,7 +250,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
            buffer_size (int): number of datapoints in the buffer
            strict (bool): use "strict mode", see notes above.
        """
-        _ParallelMapData.__init__(self, ds, buffer_size)
+        _ParallelMapData.__init__(self, ds, buffer_size, strict)
         _MultiProcessZMQDataFlow.__init__(self)
         self.nr_proc = nr_proc
         self.map_func = map_func
@@ -253,7 +258,10 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
         self._procs = []
         self._guard = DataFlowReentrantGuard()

-    def _reset_once(self):
+    def reset_state(self):
+        _MultiProcessZMQDataFlow.reset_state(self)
+        _ParallelMapData.reset_state(self)
+
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.DEALER)
         self.socket.set_hwm(self._buffer_size * 2)
@@ -266,15 +274,9 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
            self._proc_ids[k], self.map_func, pipename, worker_hwm)
            for k in range(self.nr_proc)]

-        self.ds.reset_state()
-        self._iter = self.ds.__iter__()
-
         self._start_processes()
         self._fill_buffer()     # pre-fill the bufer

-    def reset_state(self):
-        _MultiProcessZMQDataFlow.reset_state(self)
-
     def _send(self, dp):
         msg = [b"", dumps(dp)]
         self.socket.send_multipart(msg, copy=False)
@@ -286,11 +288,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):

     def __iter__(self):
         with self._guard, _zmq_catch_error('MultiProcessMapData'):
-            if self._strict:
-                for dp in self.get_data_strict():
-                    yield dp
-            else:
-                for dp in self.get_data_non_strict():
-                    yield dp
+            for dp in super(MultiProcessMapDataZMQ, self).__iter__():
+                yield dp
@@ -388,6 +386,8 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):


 if __name__ == '__main__':
+    import time
+
     class Zero(DataFlow):
         def __init__(self, size):
             self._size = size
@@ -399,8 +399,13 @@ if __name__ == '__main__':
         def __len__(self):
             return self._size

-    ds = Zero(300)
-    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
+    def f(x):
+        if x[0] < 10:
+            time.sleep(1)
+        return x
+
+    ds = Zero(100)
+    ds = MultiThreadMapData(ds, 50, f, buffer_size=50, strict=False)
     ds.reset_state()
     for k in ds:
         print("Bang!", k)
@@ -187,8 +187,6 @@ class EnqueueThread(ShareSessionThread):

 class QueueInput(FeedfreeInput):
     """ Enqueue datapoints from a DataFlow to a TF queue.
     And the model receives dequeued tensors.
-
-    Calling :meth:`reset_state()` will clear the queue and reset the dataflow.
     """
     def __init__(self, ds, queue=None):