Commit fa0f57da authored by Yuxin Wu's avatar Yuxin Wu

Use python-prctl to guarantee child process cleanup (fix #668)

parent 1a1ec3db
......@@ -24,9 +24,9 @@ because the bottleneck in this implementation is not computation but simulation.
Some practical notes:
1. Prefer Python 3; Windows not supported.
2. Occasionally, processes may not get terminated completely. It is suggested to use `systemd-run` to run any
multiprocess Python program to get a cgroup dedicated for the task.
3. Training with a significant slower speed (e.g. on CPU) will result in very bad score, probably because of the slightly off-policy implementation.
2. Training at a significantly slower speed (e.g. on CPU) will result in a very bad score, probably because of the slightly off-policy implementation.
3. Occasionally, processes may not get terminated completely.
If you're using Linux, install [python-prctl](https://pypi.org/project/python-prctl/) to prevent this.
### To test a model:
......
......@@ -16,7 +16,8 @@ import zmq
from tensorpack.utils import logger
from tensorpack.utils.serialize import loads, dumps
from tensorpack.utils.concurrency import LoopThread, ensure_proc_terminate
from tensorpack.utils.concurrency import (
LoopThread, ensure_proc_terminate, enable_death_signal)
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange',
......@@ -65,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
self.s2c = pipe_s2c
def run(self):
enable_death_signal()
player = self._build_player()
context = zmq.Context()
c2s_socket = context.socket(zmq.PUSH)
......
......@@ -65,7 +65,13 @@ class DataFlow(object):
def reset_state(self):
"""
Reset state of the dataflow. It has to be called before producing datapoints.
Reset state of the dataflow.
It **has to** be called once and only once before producing datapoints.
Note:
1. If the dataflow is forked, each process will call this method
before producing datapoints.
2. The caller thread of this method must remain alive to keep this dataflow alive.
For example, RNG **has to** be reset if used in the DataFlow,
otherwise it won't work well with prefetching, because different
......
......@@ -16,6 +16,7 @@ import atexit
from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
enable_death_signal,
StoppableThread)
from ..utils.serialize import loads, dumps
from ..utils import logger
......@@ -36,8 +37,9 @@ def _bind_guard(sock, name):
sock.bind(name)
except zmq.ZMQError:
logger.error(
"ZMQError in socket.bind(). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
"ZMQError in socket.bind('{}'). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ \
for more information.".format(name))
raise
......@@ -153,6 +155,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
self.queue = queue
def run(self):
enable_death_signal()
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
......@@ -250,6 +253,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
self.hwm = hwm
def run(self):
enable_death_signal()
self.ds.reset_state()
context = zmq.Context()
socket = context.socket(zmq.PUSH)
......
......@@ -10,7 +10,7 @@ from six.moves import queue
import zmq
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from ..utils.concurrency import StoppableThread
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils import logger
from ..utils.serialize import loads, dumps
......@@ -225,6 +225,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self.hwm = hwm
def run(self):
enable_death_signal()
ctx = zmq.Context()
socket = ctx.socket(zmq.REP)
socket.setsockopt(zmq.IDENTITY, self.identity)
......
......@@ -171,6 +171,21 @@ def ensure_proc_terminate(proc):
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def enable_death_signal():
    """
    Arrange for this process to be signalled when its parent dies,
    so that orphaned child processes are guaranteed to be cleaned up
    even if the parent exits abnormally.

    This is a no-op when the optional ``python-prctl`` package is not
    installed (e.g. on non-Linux platforms).
    """
    try:
        import prctl  # Linux-only optional dependency; best-effort.
    except ImportError:
        return
    # Deliver SIGHUP to this process on parent death.
    # NOTE(review): whether SIGHUP is the best signal choice is an open question.
    prctl.set_pdeathsig(signal.SIGHUP)
def is_main_thread():
if six.PY2:
return isinstance(threading.current_thread(), threading._MainThread)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment