Commit 79b9d0eb authored by Yuxin Wu's avatar Yuxin Wu

Guard some stateful dataflow with non-reentrancy

parent 6c905896
......@@ -3,7 +3,7 @@
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading
from abc import abstractmethod, ABCMeta
import six
from ..utils.utils import get_rng
......@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
pass
class DataFlowReentrantGuard(object):
"""
A tool to enforce thread-level non-reentrancy on DataFlow.
"""
def __init__(self):
self._lock = threading.Lock()
def __enter__(self):
self._succ = self._lock.acquire(blocking=False)
if not self._succ:
raise threading.ThreadError("This DataFlow cannot be reused under different threads!")
def __exit__(self, exc_type, exc_val, exc_tb):
self._lock.release()
return False
@six.add_metaclass(ABCMeta)
class DataFlow(object):
""" Base class for all DataFlow """
......
......@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
self.imglist = meta.get_image_list(name, dir_structure)
for fname, _ in self.imglist[:10]:
fname = os.path.join(self.full_dir, fname)
assert os.path.isfile(fname), fname
def size(self):
......
......@@ -11,7 +11,7 @@ import uuid
import os
import zmq
from .base import ProxyDataFlow, DataFlowTerminated
from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
StoppableThread)
......@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
self._size = -1
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
self._guard = DataFlowReentrantGuard()
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)]
......@@ -81,11 +83,12 @@ class PrefetchData(ProxyDataFlow):
start_proc_mask_signal(self.procs)
def get_data(self):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = self.queue.get()
yield dp
with self._guard:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = self.queue.get()
yield dp
def reset_state(self):
# do nothing. all ds are reset once and only once in spawned processes
......@@ -155,26 +158,28 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.nr_proc = nr_proc
self._hwm = hwm
self._guard = DataFlowReentrantGuard()
self._setup_done = False
def get_data(self):
try:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = loads(self.socket.recv(copy=False).bytes)
yield dp
except zmq.ContextTerminated:
logger.info("[Prefetch Master] Context terminated.")
raise DataFlowTerminated()
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[Prefetch Master] Socket closed.")
with self._guard:
try:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
dp = loads(self.socket.recv(copy=False).bytes)
yield dp
except zmq.ContextTerminated:
logger.info("[Prefetch Master] Context terminated.")
raise DataFlowTerminated()
else:
except zmq.ZMQError as e:
if e.errno == errno.ENOTSOCK: # socket closed
logger.info("[Prefetch Master] Socket closed.")
raise DataFlowTerminated()
else:
raise
except:
raise
except:
raise
def reset_state(self):
"""
......@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
t.start()
self._iter = self._iter_ds.get_data()
self._guard = DataFlowReentrantGuard()
# only call once, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer()
......@@ -332,23 +338,24 @@ class ThreadedMapData(ProxyDataFlow):
raise
def get_data(self):
for dp in self._iter:
self._in_queue.put(dp)
yield self._out_queue.get()
self._iter = self._iter_ds.get_data()
if self._strict:
# first call get() to clear the queues, then fill
for k in range(self.buffer_size):
dp = self._out_queue.get()
if k == self.buffer_size - 1:
self._fill_buffer()
yield dp
else:
for _ in range(self.buffer_size):
self._in_queue.put(next(self._iter))
with self._guard:
for dp in self._iter:
self._in_queue.put(dp)
yield self._out_queue.get()
self._iter = self._iter_ds.get_data()
if self._strict:
# first call get() to clear the queues, then fill
for k in range(self.buffer_size):
dp = self._out_queue.get()
if k == self.buffer_size - 1:
self._fill_buffer()
yield dp
else:
for _ in range(self.buffer_size):
self._in_queue.put(next(self._iter))
yield self._out_queue.get()
def __del__(self):
for p in self._threads:
p.stop()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment