Commit 1d3ab162 authored by Yuxin Wu's avatar Yuxin Wu

allow mask_sigint called in non-main threads. avoid #368 and enable nested PrefetchDataZMQ

parent 3d9dc7d0
......@@ -102,9 +102,13 @@ class PrefetchProcessZMQ(mp.Process):
self.socket = self.context.socket(zmq.PUSH)
self.socket.set_hwm(self.hwm)
self.socket.connect(self.conn_name)
while True:
for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False)
try:
while True:
for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
class PrefetchDataZMQ(ProxyDataFlow):
......@@ -117,7 +121,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
Note:
1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
2. This dataflow is not fork-safe. You cannot nest it.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, nr_proc=a), nr_proc=b)``.
A total of ``a * b`` instances of ``df`` worker processes will be created.
Also in this case some zmq pipes cannot be cleaned at exit.
3. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
......@@ -139,21 +145,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.nr_proc = nr_proc
self._hwm = hwm
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(self._hwm)
self.socket.bind(self.pipename)
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
for _ in range(self.nr_proc)]
self.start_processes()
# __del__ not guaranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
self._setup_done = False
def get_data(self):
try:
......@@ -179,13 +171,32 @@ class PrefetchDataZMQ(ProxyDataFlow):
All forked dataflows are reset **once and only once** in spawned processes.
Nothing more can be done when calling this method.
"""
pass
if self._setup_done:
return
self._setup_done = True
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(self._hwm)
self.socket.bind(self.pipename)
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
for _ in range(self.nr_proc)]
self.start_processes()
# __del__ not guaranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
def start_processes(self):
    """Fork the prefetch worker processes.

    Delegates to ``start_proc_mask_signal`` so the children are started
    with SIGINT ignored (when called from the main thread), letting the
    parent handle Ctrl-C cleanup instead of every worker dying at once.
    """
    start_proc_mask_signal(self.procs)
def __del__(self):
# on exit, logger may not be functional anymore
if not self._setup_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self.procs:
......@@ -226,12 +237,12 @@ class ThreadedMapData(ProxyDataFlow):
This is useful when the mapping function is the bottleneck, but you don't
want to start processes for the entire dataflow pipeline.
Notes:
Note:
1. There is tiny communication overhead with threads, but you
should avoid starting many threads in your main process to avoid GIL.
should avoid starting many threads in your main process to avoid GIL.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(ThreadedMapData(...), 1)`` to avoid GIL.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(ThreadedMapData(...), 1)`` to avoid GIL.
2. Threads run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
......
......@@ -165,22 +165,38 @@ def ensure_proc_terminate(proc):
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def is_main_thread():
    """Return True iff the calling thread is the process's main thread.

    Returns:
        bool: whether the current thread is the main thread.
    """
    if hasattr(threading, 'main_thread'):
        # Python >= 3.4: compare identity against the canonical main-thread
        # object. Threads are singletons, so `is` (not `==`) is the correct
        # comparison.
        return threading.current_thread() is threading.main_thread()
    # Python 2 fallback: the main thread is the unique _MainThread instance.
    # (Private API, but there is no public equivalent before 3.4.)
    return isinstance(threading.current_thread(), threading._MainThread)
@contextmanager
def mask_sigint():
    """
    Returns:
        If called in the main thread, returns a context where ``SIGINT``
        is ignored, and yields True. Otherwise yields False and leaves
        the signal disposition untouched, because ``signal.signal`` may
        only be called from the main thread.
    """
    if is_main_thread():
        sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            yield True
        finally:
            # Restore the previous handler even if the body raises,
            # so SIGINT is never left permanently ignored.
            signal.signal(signal.SIGINT, sigint_handler)
    else:
        yield False
def start_proc_mask_signal(proc):
""" Start process(es) with SIGINT ignored.
"""
Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
Note:
The signal mask is only applied when called from main thread.
"""
if not isinstance(proc, list):
proc = [proc]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment