Commit fa0f57da authored by Yuxin Wu

Use python-prctl to guarantee child process cleanup (fix #668)

parent 1a1ec3db
@@ -24,9 +24,9 @@ because the bottleneck in this implementation is not computation but simulation.
 Some practical notes:
 
 1. Prefer Python 3; Windows not supported.
-2. Occasionally, processes may not get terminated completely. It is suggested to use `systemd-run` to run any
-   multiprocess Python program to get a cgroup dedicated for the task.
-3. Training with a significantly slower speed (e.g. on CPU) will result in a very bad score, probably because of the slightly off-policy implementation.
+2. Training with a significantly slower speed (e.g. on CPU) will result in a very bad score, probably because of the slightly off-policy implementation.
+3. Occasionally, processes may not get terminated completely.
+   If you're using Linux, install [python-prctl](https://pypi.org/project/python-prctl/) to prevent this.
 
 ### To test a model:
......
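A quick way to check that the optional dependency mentioned in note 3 is actually usable (an illustrative snippet, not part of this commit):

```python
# Minimal check (Linux only): confirm the optional "python-prctl" package is
# importable, so the leftover-process cleanup described above will take effect.
try:
    import prctl  # the module installed by the python-prctl package
    print("python-prctl available; set_pdeathsig supported:",
          hasattr(prctl, "set_pdeathsig"))
except ImportError:
    print("python-prctl not installed; child processes may outlive the parent")
```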
@@ -16,7 +16,8 @@ import zmq
 from tensorpack.utils import logger
 from tensorpack.utils.serialize import loads, dumps
-from tensorpack.utils.concurrency import LoopThread, ensure_proc_terminate
+from tensorpack.utils.concurrency import (
+    LoopThread, ensure_proc_terminate, enable_death_signal)
 
 __all__ = ['SimulatorProcess', 'SimulatorMaster',
            'SimulatorProcessStateExchange',
@@ -65,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
         self.s2c = pipe_s2c
 
     def run(self):
+        enable_death_signal()
         player = self._build_player()
         context = zmq.Context()
         c2s_socket = context.socket(zmq.PUSH)
......
@@ -65,7 +65,13 @@ class DataFlow(object):
     def reset_state(self):
         """
-        Reset state of the dataflow. It has to be called before producing datapoints.
+        Reset state of the dataflow.
+        It **has to** be called once and only once before producing datapoints.
+
+        Note:
+            1. If the dataflow is forked, each process will call this method
+               before producing datapoints.
+            2. The caller thread of this method must remain alive to keep this dataflow alive.
 
         For example, RNG **has to** be reset if used in the DataFlow,
         otherwise it won't work well with prefetching, because different
......
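To make the new contract concrete, here is a minimal sketch (not part of the commit; `ToyDataFlow` and its contents are made up, assuming the `get_data`/`size`/`reset_state` interface of this tensorpack version) of a DataFlow that owns an RNG and re-seeds it in `reset_state()`, so each forked worker produces different data:

```python
import numpy as np

from tensorpack.dataflow import DataFlow


class ToyDataFlow(DataFlow):
    """Illustrative only: yields random scalars."""
    def __init__(self, size=100):
        self._size = size
        self.rng = None

    def reset_state(self):
        # Called once (and only once) per process. When the dataflow is forked
        # by a prefetcher, each child calls this again and gets its own seed,
        # so workers do not emit identical streams.
        self.rng = np.random.RandomState()

    def get_data(self):
        for _ in range(self._size):
            yield [self.rng.rand()]

    def size(self):
        return self._size
```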
@@ -16,6 +16,7 @@ import atexit
 from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
 from ..utils.concurrency import (ensure_proc_terminate,
                                  mask_sigint, start_proc_mask_signal,
+                                 enable_death_signal,
                                  StoppableThread)
 from ..utils.serialize import loads, dumps
 from ..utils import logger
@@ -36,8 +37,9 @@ def _bind_guard(sock, name):
         sock.bind(name)
     except zmq.ZMQError:
         logger.error(
-            "ZMQError in socket.bind(). Perhaps you're \
-using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
+            "ZMQError in socket.bind('{}'). Perhaps you're \
+using pipes on a non-local file system. See documentation of PrefetchDataZMQ \
+for more information.".format(name))
         raise
@@ -153,6 +155,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
             self.queue = queue
 
         def run(self):
+            enable_death_signal()
             # reset all ds so each process will produce different data
             self.ds.reset_state()
             while True:
@@ -250,6 +253,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
             self.hwm = hwm
 
         def run(self):
+            enable_death_signal()
             self.ds.reset_state()
             context = zmq.Context()
             socket = context.socket(zmq.PUSH)
......
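The `run()` methods touched above all follow the same pattern: arm the death signal first, then stream serialized datapoints to the parent over ZMQ. A standalone sketch of that pattern (hedged: the pipe name and the `worker` function are made up for illustration, and the HWM value is arbitrary), using the `enable_death_signal` and `dumps`/`loads` helpers that appear in this diff:

```python
import multiprocessing as mp

import zmq

from tensorpack.utils.concurrency import enable_death_signal
from tensorpack.utils.serialize import dumps, loads

PIPE = "ipc:///tmp/demo-prefetch-pipe"   # hypothetical endpoint


def worker():
    # Same ordering as in the diff: set the death signal before doing anything
    # else, so this process disappears if the parent is killed, then push data.
    enable_death_signal()
    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    push.set_hwm(50)
    push.connect(PIPE)
    i = 0
    while True:
        push.send(dumps([i]))
        i += 1


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.set_hwm(50)
    pull.bind(PIPE)

    proc = mp.Process(target=worker)
    proc.start()

    for _ in range(5):
        print(loads(pull.recv()))

    proc.terminate()
    proc.join()
```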
@@ -10,7 +10,7 @@ from six.moves import queue
 import zmq
 
 from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
-from ..utils.concurrency import StoppableThread
+from ..utils.concurrency import StoppableThread, enable_death_signal
 from ..utils import logger
 from ..utils.serialize import loads, dumps
@@ -225,6 +225,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
             self.hwm = hwm
 
         def run(self):
+            enable_death_signal()
             ctx = zmq.Context()
             socket = ctx.socket(zmq.REP)
             socket.setsockopt(zmq.IDENTITY, self.identity)
......
@@ -171,6 +171,21 @@ def ensure_proc_terminate(proc):
     atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
 
 
+def enable_death_signal():
+    """
+    Set the "death signal" of the current process, so that
+    the current process is guaranteed to be cleaned up
+    if the parent dies accidentally.
+    """
+    try:
+        import prctl
+    except ImportError:
+        return
+    else:
+        # is SIGHUP a good choice?
+        prctl.set_pdeathsig(signal.SIGHUP)
+
+
 def is_main_thread():
     if six.PY2:
         return isinstance(threading.current_thread(), threading._MainThread)
......
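For context on what `enable_death_signal()` buys you, here is a standalone, Linux-only sketch of the underlying `prctl(PR_SET_PDEATHSIG, ...)` mechanism (the timings and print statements are illustrative, not from the commit): the child arms the death signal, the parent is then killed with SIGKILL so it cannot run any cleanup, and the kernel still delivers SIGHUP to the child, terminating it instead of leaving an orphan behind.

```python
import multiprocessing as mp
import os
import signal
import time


def child():
    # Same mechanism as enable_death_signal() above: ask the kernel to send
    # SIGHUP to this process when its parent dies, even if the parent is
    # killed with SIGKILL and never runs atexit handlers.
    try:
        import prctl
        prctl.set_pdeathsig(signal.SIGHUP)
    except ImportError:
        print("python-prctl not installed; no death signal set")
    print("child started, parent pid =", os.getppid())
    time.sleep(60)   # without the death signal, this would linger for a minute
    print("child exiting on its own")


if __name__ == "__main__":
    p = mp.Process(target=child)
    p.start()
    time.sleep(1)
    # Simulate the parent dying "accidentally". With python-prctl installed,
    # the child receives SIGHUP almost immediately and terminates as well.
    os.kill(os.getpid(), signal.SIGKILL)
```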