Commit 6a4f4ee2 authored by Yuxin Wu

initial version of MultiProcessMapData (#414)

parent 9eaf6e92
@@ -4,6 +4,7 @@
from __future__ import print_function
import threading
from contextlib import contextmanager
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
@@ -21,7 +22,47 @@ from ..utils import logger
from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
'ThreadedMapData']
'ThreadedMapData', 'MultiThreadMapData', 'MultiProcessMapData']
def _repeat_iter(get_itr):
while True:
for x in get_itr():
yield x
def _bind_guard(sock, name):
try:
sock.bind(name)
except zmq.ZMQError:
logger.error(
"ZMQError in socket.bind(). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
raise
def _get_pipe_name(name):
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
pipename = "ipc://{}/{}-pipe-".format(pipedir.rstrip('/'), name) + str(uuid.uuid1())[:6]
return pipename
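# Example (illustrative, not in the original commit): with TENSORPACK_PIPEDIR=/tmp,
# _get_pipe_name('dataflow') returns something like "ipc:///tmp/dataflow-pipe-8a1f2c";
# the 6-character suffix comes from uuid1() and differs on every call.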
@contextmanager
def _zmq_catch_error(name):
try:
yield
except zmq.ContextTerminated:
logger.info("[{}] Context terminated.".format(name))
raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[{}] Socket closed.".format(name))
raise DataFlowTerminated()
else:
raise
except:
raise
class PrefetchProcess(mp.Process):
@@ -105,14 +146,14 @@ class PrefetchProcessZMQ(mp.Process):
def run(self):
self.ds.reset_state()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
self.socket.set_hwm(self.hwm)
self.socket.connect(self.conn_name)
context = zmq.Context()
socket = context.socket(zmq.PUSH)
socket.set_hwm(self.hwm)
socket.connect(self.conn_name)
try:
while True:
for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False)
socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
@@ -170,53 +211,34 @@ class PrefetchDataZMQ(ProxyDataFlow):
self._hwm = hwm
self._guard = DataFlowReentrantGuard()
self._setup_done = False
self._reset_done = False
def _recv(self):
return loads(self.socket.recv(copy=False).bytes)
def get_data(self):
with self._guard:
try:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = loads(self.socket.recv(copy=False).bytes)
yield dp
except zmq.ContextTerminated:
logger.info("[Prefetch Master] Context terminated.")
raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[Prefetch Master] Socket closed.")
raise DataFlowTerminated()
else:
raise
except:
raise
with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self._recv()
def reset_state(self):
"""
All forked dataflows are reset **once and only once** in spawned processes.
Calling this method more than once has no further effect.
"""
if self._setup_done:
if self._reset_done:
return
self._setup_done = True
self._reset_done = True
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(self._hwm)
try:
self.socket.bind(self.pipename)
except zmq.ZMQError:
logger.error(
"ZMQError in socket.bind(). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
raise
pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename)
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
self.procs = [PrefetchProcessZMQ(self.ds, pipename, self._hwm)
for _ in range(self.nr_proc)]
self._start_processes()
# __del__ not guaranteed to get called at exit
@@ -227,15 +249,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
start_proc_mask_signal(self.procs)
def __del__(self):
if not self._setup_done:
if not self._reset_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self.procs:
x.terminate()
try:
# TODO test if logger here would overwrite log file
print("Prefetch process exited.")
print("PrefetchDataZMQ successfully cleaned-up.")
except:
pass
@@ -263,7 +284,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
proc.start()
class ThreadedMapData(ProxyDataFlow):
class MultiThreadMapData(ProxyDataFlow):
"""
Same as :class:`MapData`, but starts threads to run the mapping function.
This is useful when the mapping function is the bottleneck, but you don't
@@ -274,7 +295,7 @@ class ThreadedMapData(ProxyDataFlow):
should avoid starting many threads in your main process to reduce GIL contention.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(ThreadedMapData(...), 1)``
Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)``
to reduce GIL contention.
2. Threads run in parallel and can take different amounts of time to run the
@@ -282,13 +303,13 @@ class ThreadedMapData(ProxyDataFlow):
preserved, and datapoints from one pass of `df.get_data()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `ThreadedMapData.get_data()`
You can use **strict mode**, where `MultiThreadMapData.get_data()`
is guaranteed to produce the exact set of datapoints that `df.get_data()`
produces, although their order is still not preserved.
"""
class _WorkerThread(StoppableThread):
def __init__(self, inq, outq, evt, map_func, strict):
super(ThreadedMapData._WorkerThread, self).__init__(evt)
super(MultiThreadMapData._WorkerThread, self).__init__(evt)
self.inq = inq
self.outq = outq
self.func = map_func
@@ -307,7 +328,7 @@ class ThreadedMapData(ProxyDataFlow):
self.outq.put(dp)
else:
assert not self._strict, \
"[ThreadedMapData] Map function cannot return None when strict mode is used."
"[MultiThreadMapData] Map function cannot return None when strict mode is used."
except:
if self.stopped():
pass # skip duplicated error messages
@@ -325,7 +346,7 @@ class ThreadedMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super(ThreadedMapData, self).__init__(ds)
super(MultiThreadMapData, self).__init__(ds)
self._iter_ds = ds
self._strict = strict
@@ -336,7 +357,7 @@ class ThreadedMapData(ProxyDataFlow):
self._evt = None
def reset_state(self):
super(ThreadedMapData, self).reset_state()
super(MultiThreadMapData, self).reset_state()
if self._threads:
self._threads[0].stop()
for t in self._threads:
@@ -345,7 +366,7 @@ class ThreadedMapData(ProxyDataFlow):
self._in_queue = queue.Queue()
self._out_queue = queue.Queue()
self._evt = threading.Event()
self._threads = [ThreadedMapData._WorkerThread(
self._threads = [MultiThreadMapData._WorkerThread(
self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
for _ in range(self.nr_thread)]
for t in self._threads:
@@ -366,7 +387,7 @@ class ThreadedMapData(ProxyDataFlow):
for _ in range(n):
self._in_queue.put(next(self._iter))
except StopIteration:
logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def get_data(self):
@@ -393,3 +414,133 @@ class ThreadedMapData(ProxyDataFlow):
self._evt.set()
for p in self._threads:
p.join()
# TODO deprecated
ThreadedMapData = MultiThreadMapData
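# --- Usage sketch (illustrative only, not part of this commit) ---
# The docstring above recommends doing the threaded mapping first and then wrapping
# the result in a single-process PrefetchDataZMQ, so the GIL contention caused by the
# worker threads happens outside the main process. `SomeCheapDataFlow` and
# `load_and_augment` below are hypothetical placeholders for a lightweight dataflow
# (e.g. one yielding filenames) and an expensive mapping function.
#
#     ds = SomeCheapDataFlow()
#     ds = MultiThreadMapData(ds, nr_thread=16, map_func=load_and_augment,
#                             buffer_size=200, strict=True)
#     ds = PrefetchDataZMQ(ds, 1)    # one forked process runs all the threads
#     ds.reset_state()
#     for dp in ds.get_data():
#         pass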
class MultiProcessMapDataZMQ(ProxyDataFlow):
class _Worker(mp.Process):
def __init__(self, identity, map_func, pipename, hwm):
super(MultiProcessMapDataZMQ._Worker, self).__init__()
self.identity = identity
self.map_func = map_func
self.pipename = pipename
self.hwm = hwm
def run(self):
ctx = zmq.Context()
socket = ctx.socket(zmq.DEALER)
socket.setsockopt(zmq.IDENTITY, self.identity)
socket.set_hwm(self.hwm)
socket.connect(self.pipename)
while True:
dp = loads(socket.recv(copy=False).bytes)
dp = self.map_func(dp)
socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200):
super(MultiProcessMapDataZMQ, self).__init__(ds)
self.nr_proc = nr_proc
self.map_func = map_func
self._buffer_size = buffer_size
self._procs = []
self._reset_done = False
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def reset_state(self):
if self._reset_done:
return
self._reset_done = True
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.set_hwm(self._buffer_size * 2)
pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
self.ds.reset_state()
self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))
self._guard = DataFlowReentrantGuard()
self._start_processes()
self._fill_buffer()
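# __del__ is not guaranteed to run at interpreter exit (see the note in
# PrefetchDataZMQ above), so cleanup is additionally registered with atexit.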
import atexit
atexit.register(lambda x: x.__del__(), self)
def _fill_buffer(self):
# Filling the buffer.
for _ in range(self._buffer_size):
self._send()
def _start_processes(self):
start_proc_mask_signal(self._procs)
def _send(self):
dp = next(self._iter_ds)
# round-robin assignment
worker = next(self._iter_worker)
msg = [worker, dumps(dp)]
self.socket.send_multipart(msg, copy=False)
def _recv(self):
msg = self.socket.recv_multipart(copy=False)
dp = loads(msg[1].bytes)
return dp
def get_data(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self._recv()
self._send()
def __del__(self):
if not self._reset_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self._procs:
x.terminate()
try:
print("MultiProcessMapData successfully cleaned-up.")
except:
pass
MultiProcessMapData = MultiProcessMapDataZMQ # alias
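# --- Protocol sketch (illustrative only, not part of this commit) ---
# MultiProcessMapDataZMQ pairs a ROUTER socket in the master with one DEALER socket
# per worker: the master addresses a worker by the IDENTITY it assigned, ZMQ strips
# that identity frame before the worker sees the payload, and the identity comes back
# with the reply, so the master can keep roughly `buffer_size` datapoints in flight
# and read results in whatever order they arrive. A minimal standalone pyzmq round
# trip of the same pattern (the ipc path and variable names are made up; the short
# sleep only lets the connect finish before the first send):
#
#     import time, zmq
#     ctx = zmq.Context()
#     router = ctx.socket(zmq.ROUTER)
#     router.bind("ipc:///tmp/router-demo")
#     dealer = ctx.socket(zmq.DEALER)
#     dealer.setsockopt(zmq.IDENTITY, b'0')
#     dealer.connect("ipc:///tmp/router-demo")
#     time.sleep(0.2)
#     router.send_multipart([b'0', b'work'])   # first frame selects worker b'0'
#     print(dealer.recv())                     # the worker only sees b'work'
#     dealer.send(b'done')
#     print(router.recv_multipart())           # [b'0', b'done']
#     ctx.destroy(0)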
if __name__ == '__main__':
from .base import DataFlow
class Naive(DataFlow):
def get_data(self):
for k in range(1000):
yield [0]
def size(self):
return 100
ds = Naive()
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
ds.reset_state()
for k in ds.get_data():
print("Bang!", k)
print("END!")