Commit fa8af3d8 authored by Yuxin Wu

refactor ZMQ dataflow (#414)

parent 6a4f4ee2
@@ -13,7 +13,7 @@ import uuid
import os
import zmq
from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
StoppableThread)
@@ -65,24 +65,53 @@ def _zmq_catch_error(name):
raise
class PrefetchProcess(mp.Process):
def __init__(self, ds, queue, reset_after_spawn=True):
class _MultiProcessZMQDataFlow(DataFlow):
def __init__(self, ds):
assert os.name != 'nt', "ZMQ IPC isn't supported on Windows!"
self._reset_done = False
self._procs = []
self.ds = ds
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def size(self):
return self.ds.size()
def reset_state(self):
"""
:param ds: ds to take data from
:param queue: output queue to put results in
All forked dataflows are reset **once and only once** in spawned processes.
Subsequent calls to this method are no-ops.
"""
super(PrefetchProcess, self).__init__()
self.ds = ds
self.queue = queue
self.reset_after_spawn = reset_after_spawn
if self._reset_done:
return
self._reset_done = True
def run(self):
if self.reset_after_spawn:
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
for dp in self.ds.get_data():
self.queue.put(dp)
# __del__ is not guaranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
self._reset_once() # build processes
def _reset_once(self):
pass
def _start_processes(self):
start_proc_mask_signal(self._procs)
def __del__(self):
if not self._reset_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self._procs:
x.terminate()
try:
print("{} successfully cleaned-up.".format(type(self).__name__))
except:
pass
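The core of this refactor is the base class above: `reset_state()` guards the once-only reset and registers cleanup, and subclasses only override `_reset_once()` to build their ZMQ plumbing. A minimal sketch of that contract (illustrative, not part of this commit; `_ToyZMQFlow` is hypothetical, `get_data()` is omitted for brevity, and it borrows `PrefetchDataZMQ._Worker` plus the module's `_get_pipe_name`/`_bind_guard` helpers):

    class _ToyZMQFlow(_MultiProcessZMQDataFlow):
        def __init__(self, ds, nr_proc=1, hwm=50):
            super(_ToyZMQFlow, self).__init__(ds)
            self.nr_proc = nr_proc
            self._hwm = hwm

        def _reset_once(self):
            # parent-side PULL socket; workers PUSH datapoints into it
            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.PULL)
            self.socket.set_hwm(self._hwm)
            pipename = _get_pipe_name('toy-dataflow')
            _bind_guard(self.socket, pipename)
            self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm)
                           for _ in range(self.nr_proc)]
            self._start_processes()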
class PrefetchData(ProxyDataFlow):
@@ -102,6 +131,20 @@ class PrefetchData(ProxyDataFlow):
This is different from the behavior of :class:`PrefetchDataZMQ`
4. `reset_state()` is a no-op. The worker processes won't be restarted.
"""
class _Worker(mp.Process):
def __init__(self, ds, queue):
super(PrefetchData._Worker, self).__init__()
self.ds = ds
self.queue = queue
def run(self):
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
for dp in self.ds.get_data():
self.queue.put(dp)
def __init__(self, ds, nr_prefetch, nr_proc):
"""
Args:
@@ -119,7 +162,7 @@ class PrefetchData(ProxyDataFlow):
self._guard = DataFlowReentrantGuard()
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
self.procs = [PrefetchData._Worker(self.ds, self.queue)
for _ in range(self.nr_proc)]
ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs)
@@ -137,34 +180,12 @@ class PrefetchData(ProxyDataFlow):
pass
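A usage sketch for `PrefetchData` (illustrative; `my_df` stands for any DataFlow). The workers are forked and started inside `__init__`, and per note 4 above `reset_state()` is a no-op:

    df = PrefetchData(my_df, nr_prefetch=16, nr_proc=4)
    df.reset_state()              # no-op; workers are already running
    for dp in df.get_data():
        pass                      # datapoints arrive via the mp.Queue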
class PrefetchProcessZMQ(mp.Process):
def __init__(self, ds, conn_name, hwm):
super(PrefetchProcessZMQ, self).__init__()
self.ds = ds
self.conn_name = conn_name
self.hwm = hwm
def run(self):
self.ds.reset_state()
context = zmq.Context()
socket = context.socket(zmq.PUSH)
socket.set_hwm(self.hwm)
socket.connect(self.conn_name)
try:
while True:
for dp in self.ds.get_data():
socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
class PrefetchDataZMQ(ProxyDataFlow):
class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
"""
Prefetch data from a DataFlow using multiple processes, with ZeroMQ for
communication.
It will fork the process calling :meth:`reset_state()`,
collect datapoints from `ds` in each process by ZeroMQ IPC pipe.
It will fork the calling process of :meth:`reset_state()`,
and collect datapoints from `ds` in each process by ZeroMQ IPC pipe.
Note:
1. An iterator cannot run faster automatically -- what's happening is
@@ -194,6 +215,28 @@ class PrefetchDataZMQ(ProxyDataFlow):
which points to a local directory.
5. Calling `reset_state()` more than once is a no-op, i.e. the worker processes won't be forked again.
"""
class _Worker(mp.Process):
def __init__(self, ds, conn_name, hwm):
super(PrefetchDataZMQ._Worker, self).__init__()
self.ds = ds
self.conn_name = conn_name
self.hwm = hwm
def run(self):
self.ds.reset_state()
context = zmq.Context()
socket = context.socket(zmq.PUSH)
socket.set_hwm(self.hwm)
socket.connect(self.conn_name)
try:
while True:
for dp in self.ds.get_data():
socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
def __init__(self, ds, nr_proc=1, hwm=50):
"""
Args:
@@ -201,12 +244,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
nr_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
"""
assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! PrefetchData might work sometimes."
super(PrefetchDataZMQ, self).__init__(ds)
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc
self._hwm = hwm
@@ -223,42 +262,16 @@ class PrefetchDataZMQ(ProxyDataFlow):
break
yield self._recv()
def reset_state(self):
"""
All forked dataflows are reset **once and only once** in spawned processes.
Subsequent calls to this method are no-ops.
"""
if self._reset_done:
return
self._reset_done = True
def _reset_once(self):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(self._hwm)
pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename)
self.procs = [PrefetchProcessZMQ(self.ds, pipename, self._hwm)
for _ in range(self.nr_proc)]
self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm)
for _ in range(self.nr_proc)]
self._start_processes()
# __del__ is not guaranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
def _start_processes(self):
start_proc_mask_signal(self.procs)
def __del__(self):
if not self._reset_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self.procs:
x.terminate()
try:
print("PrefetchDataZMQ successfully cleaned-up.")
except:
pass
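A matching sketch for `PrefetchDataZMQ` (illustrative; `my_df` stands for any DataFlow). Here the fork happens on the first `reset_state()` call instead of in `__init__`:

    df = PrefetchDataZMQ(my_df, nr_proc=4, hwm=50)
    df.reset_state()              # forks the workers, once and only once
    for dp in df.get_data():
        pass                      # datapoints arrive over the ZMQ IPC pipe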
class PrefetchOnGPUs(PrefetchDataZMQ):
@@ -279,7 +292,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
def _start_processes(self):
with mask_sigint():
for gpu, proc in zip(self.gpus, self.procs):
for gpu, proc in zip(self.gpus, self._procs):
with change_gpu(gpu):
proc.start()
@@ -307,14 +320,13 @@ class MultiThreadMapData(ProxyDataFlow):
is guaranteed to produce the exact set which `df.get_data()`
produces, although the order of data isn't preserved.
"""
class _WorkerThread(StoppableThread):
def __init__(self, inq, outq, evt, map_func, strict):
super(MultiThreadMapData._WorkerThread, self).__init__(evt)
class _Worker(StoppableThread):
def __init__(self, inq, outq, evt, map_func):
super(MultiThreadMapData._Worker, self).__init__(evt)
self.inq = inq
self.outq = outq
self.func = map_func
self.daemon = True
self._strict = strict
def run(self):
try:
@@ -322,13 +334,8 @@ class MultiThreadMapData(ProxyDataFlow):
dp = self.queue_get_stoppable(self.inq)
if self.stopped():
return
dp = self.func(dp)
if dp is not None:
self.outq.put(dp)
else:
assert not self._strict, \
"[MultiThreadMapData] Map function cannot return None when strict mode is used."
# cannot ignore None here: it would lead to unsynced send/recv
self.outq.put(self.func(dp))
except:
if self.stopped():
pass # skip duplicated error messages
@@ -348,7 +355,6 @@ class MultiThreadMapData(ProxyDataFlow):
"""
super(MultiThreadMapData, self).__init__(ds)
self._iter_ds = ds
self._strict = strict
self.nr_thread = nr_thread
self.buffer_size = buffer_size
@@ -366,13 +372,13 @@ class MultiThreadMapData(ProxyDataFlow):
self._in_queue = queue.Queue()
self._out_queue = queue.Queue()
self._evt = threading.Event()
self._threads = [MultiThreadMapData._WorkerThread(
self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
self._threads = [MultiThreadMapData._Worker(
self._in_queue, self._out_queue, self._evt, self.map_func)
for _ in range(self.nr_thread)]
for t in self._threads:
t.start()
self._iter = self._iter_ds.get_data()
self._iter = self.ds.get_data()
self._guard = DataFlowReentrantGuard()
# only call once, to ensure inq+outq has a total of buffer_size elements
@@ -390,24 +396,31 @@ class MultiThreadMapData(ProxyDataFlow):
logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def _recv(self):
ret = self._out_queue.get()
if ret is None:
assert not self._strict, \
"[MultiThreadMapData] Map function cannot return None when strict mode is used."
return ret
def get_data(self):
with self._guard:
for dp in self._iter:
self._in_queue.put(dp)
yield self._out_queue.get()
yield self._recv()
self._iter = self._iter_ds.get_data()
self._iter = self.ds.get_data()
if self._strict:
# first call get() to clear the queues, then fill
for k in range(self.buffer_size):
dp = self._out_queue.get()
dp = self._recv()
if k == self.buffer_size - 1:
self._fill_buffer()
yield dp
else:
for _ in range(self.buffer_size):
self._in_queue.put(next(self._iter))
yield self._out_queue.get()
yield self._recv()
def __del__(self):
if self._evt is not None:
@@ -420,7 +433,11 @@ class MultiThreadMapData(ProxyDataFlow):
ThreadedMapData = MultiThreadMapData
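A usage sketch for `MultiThreadMapData` (illustrative; `my_df` and `load_image` are placeholders). With `strict=True`, `map_func` must never return None, and the output is exactly the set the underlying dataflow produces:

    df = MultiThreadMapData(my_df, nr_thread=8, map_func=load_image,
                            buffer_size=100, strict=True)
    df.reset_state()              # starts the worker threads, fills the buffer
    for dp in df.get_data():
        pass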
class MultiProcessMapDataZMQ(ProxyDataFlow):
class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
"""
Same as :class:`MapData`, but starts processes to run the mapping function,
and communicates with them through a ZeroMQ pipe.
"""
class _Worker(mp.Process):
def __init__(self, identity, map_func, pipename, hwm):
super(MultiProcessMapDataZMQ._Worker, self).__init__()
@@ -442,57 +459,50 @@ class MultiProcessMapDataZMQ(ProxyDataFlow):
socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200):
"""
Args:
ds (DataFlow): the dataflow to map
nr_proc (int): number of processes to use
map_func (callable): datapoint -> datapoint | None
buffer_size (int): number of datapoints in the buffer
"""
super(MultiProcessMapDataZMQ, self).__init__(ds)
self.nr_proc = nr_proc
self.map_func = map_func
self._buffer_size = buffer_size
self.buffer_size = buffer_size
self._procs = []
self._reset_done = False
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def reset_state(self):
if self._reset_done:
return
self._reset_done = True
self._guard = DataFlowReentrantGuard()
def _reset_once(self):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.set_hwm(self._buffer_size * 2)
self.socket.set_hwm(self.buffer_size * 2)
pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
worker_hwm = int(self.buffer_size * 2 // self.nr_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
self.ds.reset_state()
self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
self._iter = self.ds.get_data()
self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))
self._guard = DataFlowReentrantGuard()
self._start_processes()
self._fill_buffer()
import atexit
atexit.register(lambda x: x.__del__(), self)
def _fill_buffer(self):
# Filling the buffer.
for _ in range(self._buffer_size):
self._send()
def _start_processes(self):
start_proc_mask_signal(self._procs)
try:
for _ in range(self.buffer_size):
self._send(next(self._iter))
except StopIteration:
logger.error("[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def _send(self):
dp = next(self._iter_ds)
def _send(self, dp):
# round-robin assignment
worker = next(self._iter_worker)
msg = [worker, dumps(dp)]
@@ -505,40 +515,32 @@ class MultiProcessMapDataZMQ(ProxyDataFlow):
def get_data(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
for dp in self._iter:
self._send(dp)
yield self._recv()
self._send()
def __del__(self):
if not self._reset_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self._procs:
x.terminate()
try:
print("MultiProcessMapData successfully cleaned-up.")
except:
pass
self._iter = self.ds.get_data() # refresh
for _ in range(self.buffer_size):
self._send(next(self._iter))
yield self._recv()
MultiProcessMapData = MultiProcessMapDataZMQ # alias
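The dispatch above relies on a standard ZeroMQ pattern: a ROUTER socket targets a specific peer by prefixing the message with that peer's identity frame, so cycling through the identities gives round-robin assignment. A self-contained sketch of the idea (illustrative assumptions: DEALER echo workers and the `ipc:///tmp/toy-pipe` address; the commit's actual worker socket setup is elided from this diff):

    import itertools
    import multiprocessing as mp
    import zmq

    PIPE = 'ipc:///tmp/toy-pipe'      # assumption: any writable ipc path

    def _echo_worker(identity):
        ctx = zmq.Context()
        sock = ctx.socket(zmq.DEALER)
        sock.setsockopt(zmq.IDENTITY, identity)
        sock.connect(PIPE)
        while True:
            msg = sock.recv()         # ROUTER already stripped the identity frame
            sock.send(msg)            # ROUTER will receive [identity, msg]

    def round_robin_demo():
        ctx = zmq.Context()
        router = ctx.socket(zmq.ROUTER)
        router.bind(PIPE)
        idents = [b'0', b'1', b'2']
        procs = [mp.Process(target=_echo_worker, args=(i,)) for i in idents]
        for p in procs:
            p.daemon = True
            p.start()
        rr = itertools.cycle(idents)  # round-robin over worker identities
        for _ in range(6):
            router.send_multipart([next(rr), b'payload'])
            print(router.recv_multipart())   # [identity, b'payload']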
if __name__ == '__main__':
from .base import DataFlow
class Zero(DataFlow):
def __init__(self, size):
self._size = size
class Naive(DataFlow):
def get_data(self):
for k in range(1000):
for k in range(self._size):
yield [0]
def size(self):
return 100
return self._size
ds = Naive()
ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
ds.reset_state()
for k in ds.get_data():
......