Commit ba9d1793 authored by Yuxin Wu's avatar Yuxin Wu

exception handling in ZMQ runner

parent be51dd88
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# File: parallel.py # File: parallel.py
import atexit import atexit
import pickle
import errno import errno
import traceback
import itertools import itertools
import multiprocessing as mp import multiprocessing as mp
import os import os
...@@ -25,6 +27,25 @@ __all__ = ['PrefetchData', 'MultiProcessPrefetchData', ...@@ -25,6 +27,25 @@ __all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'PrefetchDataZMQ', 'MultiThreadPrefetchData'] 'PrefetchDataZMQ', 'MultiThreadPrefetchData']
# from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/__init__.py
class _ExceptionWrapper:
MAGIC = b"EXC_MAGIC"
"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info):
# It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ]
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
def pack(self):
return self.MAGIC + pickle.dumps(self)
@staticmethod
def unpack(dp):
if isinstance(dp, bytes) and dp.startswith(_ExceptionWrapper.MAGIC):
return pickle.loads(dp[len(_ExceptionWrapper.MAGIC):])
def _repeat_iter(get_itr): def _repeat_iter(get_itr):
while True: while True:
for x in get_itr(): for x in get_itr():
...@@ -291,14 +312,21 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow): ...@@ -291,14 +312,21 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
def run(self): def run(self):
enable_death_signal(_warn=self.idx == 0) enable_death_signal(_warn=self.idx == 0)
self.ds.reset_state() self.ds.reset_state()
itr = _repeat_iter(lambda: self.ds)
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.PUSH) socket = context.socket(zmq.PUSH)
socket.set_hwm(self.hwm) socket.set_hwm(self.hwm)
socket.connect(self.conn_name) socket.connect(self.conn_name)
try: try:
while True: while True:
for dp in self.ds: try:
dp = next(itr)
socket.send(dumps(dp), copy=False)
except Exception:
dp = _ExceptionWrapper(sys.exc_info()).pack()
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
raise
# sigint could still propagate here, e.g. when nested # sigint could still propagate here, e.g. when nested
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
...@@ -332,7 +360,12 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow): ...@@ -332,7 +360,12 @@ class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
self._size = -1 self._size = -1
def _recv(self): def _recv(self):
return loads(self.socket.recv(copy=False)) ret = loads(self.socket.recv(copy=False))
exc = _ExceptionWrapper.unpack(ret)
if exc is not None:
logger.error("Exception '{}' in worker:".format(str(exc.exc_type)))
raise exc.exc_type(exc.exc_msg)
return ret
def __len__(self): def __len__(self):
return self.ds.__len__() return self.ds.__len__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment