Commit 6a4f4ee2 authored by Yuxin Wu

initial version of MultiProcessMapData (#414)

parent 9eaf6e92
@@ -4,6 +4,7 @@
 from __future__ import print_function
 import threading
+from contextlib import contextmanager
 import multiprocessing as mp
 import itertools
 from six.moves import range, zip, queue
@@ -21,7 +22,47 @@ from ..utils import logger
 from ..utils.gpu import change_gpu
 
 __all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
-           'ThreadedMapData']
+           'ThreadedMapData', 'MultiThreadMapData', 'MultiProcessMapData']
+
+
+def _repeat_iter(get_itr):
+    while True:
+        for x in get_itr():
+            yield x
+
+
+def _bind_guard(sock, name):
+    try:
+        sock.bind(name)
+    except zmq.ZMQError:
+        logger.error(
+            "ZMQError in socket.bind(). Perhaps you're \
+using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
+        raise
+
+
+def _get_pipe_name(name):
+    pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
+    assert os.path.isdir(pipedir), pipedir
+    pipename = "ipc://{}/{}-pipe-".format(pipedir.rstrip('/'), name) + str(uuid.uuid1())[:6]
+    return pipename
+
+
+@contextmanager
+def _zmq_catch_error(name):
+    try:
+        yield
+    except zmq.ContextTerminated:
+        logger.info("[{}] Context terminated.".format(name))
+        raise DataFlowTerminated()
+    except zmq.ZMQError as e:
+        if e.errno == errno.ENOTSOCK:       # socket closed
+            logger.info("[{}] Socket closed.".format(name))
+            raise DataFlowTerminated()
+        else:
+            raise
+    except:
+        raise
+
+
 class PrefetchProcess(mp.Process):
@@ -105,14 +146,14 @@ class PrefetchProcessZMQ(mp.Process):
     def run(self):
         self.ds.reset_state()
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PUSH)
-        self.socket.set_hwm(self.hwm)
-        self.socket.connect(self.conn_name)
+        context = zmq.Context()
+        socket = context.socket(zmq.PUSH)
+        socket.set_hwm(self.hwm)
+        socket.connect(self.conn_name)
         try:
             while True:
                 for dp in self.ds.get_data():
-                    self.socket.send(dumps(dp), copy=False)
+                    socket.send(dumps(dp), copy=False)
         # sigint could still propagate here, e.g. when nested
         except KeyboardInterrupt:
             pass
@@ -170,53 +211,34 @@ class PrefetchDataZMQ(ProxyDataFlow):
         self._hwm = hwm
         self._guard = DataFlowReentrantGuard()
-        self._setup_done = False
+        self._reset_done = False
+
+    def _recv(self):
+        return loads(self.socket.recv(copy=False).bytes)
 
     def get_data(self):
-        with self._guard:
-            try:
-                for k in itertools.count():
-                    if self._size > 0 and k >= self._size:
-                        break
-                    dp = loads(self.socket.recv(copy=False).bytes)
-                    yield dp
-            except zmq.ContextTerminated:
-                logger.info("[Prefetch Master] Context terminated.")
-                raise DataFlowTerminated()
-            except zmq.ZMQError as e:
-                if e.errno == errno.ENOTSOCK:       # socket closed
-                    logger.info("[Prefetch Master] Socket closed.")
-                    raise DataFlowTerminated()
-                else:
-                    raise
-            except:
-                raise
+        with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
+            for k in itertools.count():
+                if self._size > 0 and k >= self._size:
+                    break
+                yield self._recv()
 
     def reset_state(self):
         """
         All forked dataflows are reset **once and only once** in spawned processes.
         Nothing more can be done when calling this method.
         """
-        if self._setup_done:
+        if self._reset_done:
             return
-        self._setup_done = True
+        self._reset_done = True
 
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
-        pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
-        assert os.path.isdir(pipedir), pipedir
-        self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
         self.socket.set_hwm(self._hwm)
-        try:
-            self.socket.bind(self.pipename)
-        except zmq.ZMQError:
-            logger.error(
-                "ZMQError in socket.bind(). Perhaps you're \
-using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
-            raise
+        pipename = _get_pipe_name('dataflow')
+        _bind_guard(self.socket, pipename)
 
-        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
+        self.procs = [PrefetchProcessZMQ(self.ds, pipename, self._hwm)
                       for _ in range(self.nr_proc)]
         self._start_processes()
         # __del__ not guranteed to get called at exit
@@ -227,15 +249,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
         start_proc_mask_signal(self.procs)
 
     def __del__(self):
-        if not self._setup_done:
+        if not self._reset_done:
             return
         if not self.context.closed:
             self.context.destroy(0)
         for x in self.procs:
             x.terminate()
         try:
-            # TODO test if logger here would overwrite log file
-            print("Prefetch process exited.")
+            print("PrefetchDataZMQ successfully cleaned-up.")
         except:
             pass
@@ -263,7 +284,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
                     proc.start()
 
 
-class ThreadedMapData(ProxyDataFlow):
+class MultiThreadMapData(ProxyDataFlow):
     """
     Same as :class:`MapData`, but start threads to run the mapping function.
     This is useful when the mapping function is the bottleneck, but you don't
@@ -274,7 +295,7 @@ class ThreadedMapData(ProxyDataFlow):
           should avoid starting many threads in your main process to reduce GIL contention.
 
           The threads will only start in the process which calls :meth:`reset_state()`.
-          Therefore you can use ``PrefetchDataZMQ(ThreadedMapData(...), 1)``
+          Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)``
           to reduce GIL contention.
 
        2. Threads run in parallel and can take different time to run the
@@ -282,13 +303,13 @@ class ThreadedMapData(ProxyDataFlow):
           preserved, and datapoints from one pass of `df.get_data()` might get
           mixed with datapoints from the next pass.
 
-          You can use **strict mode**, where `ThreadedMapData.get_data()`
+          You can use **strict mode**, where `MultiThreadMapData.get_data()`
           is guranteed to produce the exact set which `df.get_data()`
           produces. Although the order of data still isn't preserved.
    """
 
    class _WorkerThread(StoppableThread):
        def __init__(self, inq, outq, evt, map_func, strict):
-            super(ThreadedMapData._WorkerThread, self).__init__(evt)
+            super(MultiThreadMapData._WorkerThread, self).__init__(evt)
            self.inq = inq
            self.outq = outq
            self.func = map_func
@@ -307,7 +328,7 @@ class ThreadedMapData(ProxyDataFlow):
                         self.outq.put(dp)
                     else:
                         assert not self._strict, \
-                            "[ThreadedMapData] Map function cannot return None when strict mode is used."
+                            "[MultiThreadMapData] Map function cannot return None when strict mode is used."
             except:
                 if self.stopped():
                     pass        # skip duplicated error messages
@@ -325,7 +346,7 @@ class ThreadedMapData(ProxyDataFlow):
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
         """
-        super(ThreadedMapData, self).__init__(ds)
+        super(MultiThreadMapData, self).__init__(ds)
         self._iter_ds = ds
         self._strict = strict
@@ -336,7 +357,7 @@ class ThreadedMapData(ProxyDataFlow):
         self._evt = None
 
     def reset_state(self):
-        super(ThreadedMapData, self).reset_state()
+        super(MultiThreadMapData, self).reset_state()
         if self._threads:
             self._threads[0].stop()
             for t in self._threads:
@@ -345,7 +366,7 @@ class ThreadedMapData(ProxyDataFlow):
         self._in_queue = queue.Queue()
         self._out_queue = queue.Queue()
         self._evt = threading.Event()
-        self._threads = [ThreadedMapData._WorkerThread(
+        self._threads = [MultiThreadMapData._WorkerThread(
             self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
             for _ in range(self.nr_thread)]
         for t in self._threads:
@@ -366,7 +387,7 @@ class ThreadedMapData(ProxyDataFlow):
             for _ in range(n):
                 self._in_queue.put(next(self._iter))
         except StopIteration:
-            logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
+            logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
             raise
 
     def get_data(self):
@@ -393,3 +414,133 @@ class ThreadedMapData(ProxyDataFlow):
             self._evt.set()
         for p in self._threads:
             p.join()
+
+
+# TODO deprecated
+ThreadedMapData = MultiThreadMapData
+
+
+class MultiProcessMapDataZMQ(ProxyDataFlow):
+    class _Worker(mp.Process):
+        def __init__(self, identity, map_func, pipename, hwm):
+            super(MultiProcessMapDataZMQ._Worker, self).__init__()
+            self.identity = identity
+            self.map_func = map_func
+            self.pipename = pipename
+            self.hwm = hwm
+
+        def run(self):
+            ctx = zmq.Context()
+            socket = ctx.socket(zmq.DEALER)
+            socket.setsockopt(zmq.IDENTITY, self.identity)
+            socket.set_hwm(self.hwm)
+            socket.connect(self.pipename)
+
+            while True:
+                dp = loads(socket.recv(copy=False).bytes)
+                dp = self.map_func(dp)
+                socket.send(dumps(dp), copy=False)
+
+    def __init__(self, ds, nr_proc, map_func, buffer_size=200):
+        super(MultiProcessMapDataZMQ, self).__init__(ds)
+        self.nr_proc = nr_proc
+        self.map_func = map_func
+        self._buffer_size = buffer_size
+        self._procs = []
+        self._reset_done = False
+
+        try:
+            self._size = ds.size()
+        except NotImplementedError:
+            self._size = -1
+
+    def reset_state(self):
+        if self._reset_done:
+            return
+        self._reset_done = True
+
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.ROUTER)
+        self.socket.set_hwm(self._buffer_size * 2)
+        pipename = _get_pipe_name('dataflow-map')
+        _bind_guard(self.socket, pipename)
+
+        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
+        worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
+        self._procs = [MultiProcessMapDataZMQ._Worker(
+            self._proc_ids[k], self.map_func, pipename, worker_hwm)
+            for k in range(self.nr_proc)]
+
+        self.ds.reset_state()
+        self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
+        self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))
+
+        self._guard = DataFlowReentrantGuard()
+        self._start_processes()
+        self._fill_buffer()
+
+        import atexit
+        atexit.register(lambda x: x.__del__(), self)
+
+    def _fill_buffer(self):
+        # Filling the buffer.
+        for _ in range(self._buffer_size):
+            self._send()
+
+    def _start_processes(self):
+        start_proc_mask_signal(self._procs)
+
+    def _send(self):
+        dp = next(self._iter_ds)
+        # round-robin assignment
+        worker = next(self._iter_worker)
+        msg = [worker, dumps(dp)]
+        self.socket.send_multipart(msg, copy=False)
+
+    def _recv(self):
+        msg = self.socket.recv_multipart(copy=False)
+        dp = loads(msg[1].bytes)
+        return dp
+
+    def get_data(self):
+        with self._guard, _zmq_catch_error('MultiProcessMapData'):
+            for k in itertools.count():
+                if self._size > 0 and k >= self._size:
+                    break
+                yield self._recv()
+                self._send()
+
+    def __del__(self):
+        if not self._reset_done:
+            return
+        if not self.context.closed:
+            self.context.destroy(0)
+        for x in self._procs:
+            x.terminate()
+        try:
+            print("MultiProcessMapData successfully cleaned-up.")
+        except:
+            pass
+
+
+MultiProcessMapData = MultiProcessMapDataZMQ  # alias
+
+
+if __name__ == '__main__':
+    from .base import DataFlow
+
+    class Naive(DataFlow):
+        def get_data(self):
+            for k in range(1000):
+                yield [0]
+
+        def size(self):
+            return 100
+
+    ds = Naive()
+    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
+    ds.reset_state()
+    for k in ds.get_data():
+        print("Bang!", k)
+    print("END!")
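On the mechanism behind MultiProcessMapDataZMQ: the parent's ROUTER socket addresses each worker by the identity set on the worker's DEALER socket, so send_multipart([identity, payload]) delivers the payload to exactly that worker and the reply comes back prefixed with the same identity. _send() rotates through the identities round-robin, _fill_buffer() puts buffer_size datapoints in flight, and get_data() issues one new _send() per _recv(), so the buffer stays full. A standalone sketch of just this ROUTER/DEALER addressing pattern (the pipe path, identities and the toy "map" are illustrative, not part of the commit):

import multiprocessing as mp
import time
import zmq

PIPE = "ipc:///tmp/demo-map-pipe"   # illustrative pipe name

def worker(identity):
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, identity)   # must be set before connect()
    sock.connect(PIPE)
    while True:
        msg = sock.recv()          # the ROUTER strips the identity frame before delivery
        sock.send(msg.upper())     # toy "map function"; result goes back to the ROUTER

if __name__ == '__main__':
    ctx = zmq.Context()
    router = ctx.socket(zmq.ROUTER)
    router.bind(PIPE)
    ids = [b'0', b'1']
    procs = [mp.Process(target=worker, args=(i,), daemon=True) for i in ids]
    for p in procs:
        p.start()
    time.sleep(1)   # crude wait: a ROUTER silently drops messages addressed to identities
                    # that have not connected yet
    for k in range(4):
        router.send_multipart([ids[k % 2], b'datapoint-%d' % k])   # round-robin assignment
    for _ in range(4):
        ident, payload = router.recv_multipart()   # [worker identity, mapped payload]
        print(ident, payload)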