Commit 1d3ab162 authored by Yuxin Wu's avatar Yuxin Wu

allow mask_sigint called in non-main threads. avoid #368 and enable nested PrefetchDataZMQ

parent 3d9dc7d0
......@@ -102,9 +102,13 @@ class PrefetchProcessZMQ(mp.Process):
self.socket = self.context.socket(zmq.PUSH)
self.socket.set_hwm(self.hwm)
self.socket.connect(self.conn_name)
try:
while True:
for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
pass
class PrefetchDataZMQ(ProxyDataFlow):
......@@ -117,7 +121,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
Note:
1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
2. This dataflow is not fork-safe. You cannot nest it.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, nr_proc=a), nr_proc=b)``.
A total of ``a * b`` instances of ``df`` worker processes will be created.
Also in this case some zmq pipes cannot be cleaned at exit.
3. The underlying dataflow worker will be forked multiple times When ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
......@@ -139,21 +145,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.nr_proc = nr_proc
self._hwm = hwm
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(self._hwm)
self.socket.bind(self.pipename)
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
for _ in range(self.nr_proc)]
self.start_processes()
# __del__ not guranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
self._setup_done = False
def get_data(self):
try:
......@@ -179,13 +171,32 @@ class PrefetchDataZMQ(ProxyDataFlow):
All forked dataflows are reset **once and only once** in spawned processes.
Nothing more can be done when calling this method.
"""
pass
if self._setup_done:
return
self._setup_done = True
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
assert os.path.isdir(pipedir), pipedir
self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
self.socket.set_hwm(self._hwm)
self.socket.bind(self.pipename)
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
for _ in range(self.nr_proc)]
self.start_processes()
# __del__ not guranteed to get called at exit
import atexit
atexit.register(lambda x: x.__del__(), self)
def start_processes(self):
start_proc_mask_signal(self.procs)
def __del__(self):
# on exit, logger may not be functional anymore
if not self._setup_done:
return
if not self.context.closed:
self.context.destroy(0)
for x in self.procs:
......@@ -226,7 +237,7 @@ class ThreadedMapData(ProxyDataFlow):
This is useful when the mapping function is the bottleneck, but you don't
want to start processes for the entire dataflow pipeline.
Notes:
Note:
1. There is tiny communication overhead with threads, but you
should avoid starting many threads in your main process to avoid GIL.
......
......@@ -165,22 +165,38 @@ def ensure_proc_terminate(proc):
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def is_main_thread():
if six.PY2:
return isinstance(threading.current_thread(), threading._MainThread)
else:
# a nicer solution with py3
return threading.current_thread() == threading.main_thread()
@contextmanager
def mask_sigint():
"""
Returns:
a context where ``SIGINT`` is ignored.
If called in main thread, returns a context where ``SIGINT`` is ignored, and yield True.
Otherwise yield False.
"""
if is_main_thread():
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
yield
yield True
signal.signal(signal.SIGINT, sigint_handler)
else:
yield False
def start_proc_mask_signal(proc):
""" Start process(es) with SIGINT ignored.
"""
Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
Note:
The signal mask is only applied when called from main thread.
"""
if not isinstance(proc, list):
proc = [proc]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment