Commit 79b9d0eb authored by Yuxin Wu

Guard some stateful dataflow with non-reentrancy

parent 6c905896
@@ -3,7 +3,7 @@
 # File: base.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import threading
 from abc import abstractmethod, ABCMeta
 import six
 from ..utils.utils import get_rng
@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
     pass


+class DataFlowReentrantGuard(object):
+    """
+    A tool to enforce thread-level non-reentrancy on DataFlow.
+    """
+    def __init__(self):
+        self._lock = threading.Lock()
+
+    def __enter__(self):
+        self._succ = self._lock.acquire(blocking=False)
+        if not self._succ:
+            raise threading.ThreadError("This DataFlow cannot be reused under different threads!")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._lock.release()
+        return False
+
+
 @six.add_metaclass(ABCMeta)
 class DataFlow(object):
     """ Base class for all DataFlow """
......
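For reference, a minimal usage sketch (not part of the commit): a dataflow holds the guard inside get_data(), so a second thread that tries to iterate the same instance fails fast with threading.ThreadError instead of silently corrupting shared iteration state. ToyDataFlow is a hypothetical stand-in for the guarded classes below, and the import path assumes the class lands in tensorpack/dataflow/base.py as in this commit:

# Sketch only -- ToyDataFlow is hypothetical; the import assumes the class
# lives in tensorpack/dataflow/base.py as added by this commit.
import threading
from tensorpack.dataflow.base import DataFlowReentrantGuard


class ToyDataFlow(object):
    """Hypothetical stateful dataflow: concurrent iteration would corrupt _count."""
    def __init__(self):
        self._guard = DataFlowReentrantGuard()
        self._count = 0

    def get_data(self):
        with self._guard:               # held for the whole iteration
            for _ in range(5):
                self._count += 1
                yield self._count


df = ToyDataFlow()
it = df.get_data()
print(next(it))                         # 1 -- this thread now holds the guard


def second_consumer():
    try:
        next(df.get_data())             # guard already held -> ThreadError
    except threading.ThreadError as e:
        print("rejected:", e)


t = threading.Thread(target=second_consumer)
t.start()
t.join()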
@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
         self.imglist = meta.get_image_list(name, dir_structure)

         for fname, _ in self.imglist[:10]:
+            fname = os.path.join(self.full_dir, fname)
             assert os.path.isfile(fname), fname

     def size(self):
......
@@ -11,7 +11,7 @@ import uuid
 import os
 import zmq

-from .base import ProxyDataFlow, DataFlowTerminated
+from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
 from ..utils.concurrency import (ensure_proc_terminate,
                                  mask_sigint, start_proc_mask_signal,
                                  StoppableThread)
@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
         self._size = -1
         self.nr_proc = nr_proc
         self.nr_prefetch = nr_prefetch
+        self._guard = DataFlowReentrantGuard()
+
         self.queue = mp.Queue(self.nr_prefetch)
         self.procs = [PrefetchProcess(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
@@ -81,11 +83,12 @@ class PrefetchData(ProxyDataFlow):
         start_proc_mask_signal(self.procs)

     def get_data(self):
-        for k in itertools.count():
-            if self._size > 0 and k >= self._size:
-                break
-            dp = self.queue.get()
-            yield dp
+        with self._guard:
+            for k in itertools.count():
+                if self._size > 0 and k >= self._size:
+                    break
+                dp = self.queue.get()
+                yield dp

     def reset_state(self):
         # do nothing. all ds are reset once and only once in spawned processes
@@ -155,26 +158,28 @@ class PrefetchDataZMQ(ProxyDataFlow):
         self.nr_proc = nr_proc
         self._hwm = hwm
+        self._guard = DataFlowReentrantGuard()
         self._setup_done = False

     def get_data(self):
-        try:
-            for k in itertools.count():
-                if self._size > 0 and k >= self._size:
-                    break
-                dp = loads(self.socket.recv(copy=False).bytes)
-                yield dp
-        except zmq.ContextTerminated:
-            logger.info("[Prefetch Master] Context terminated.")
-            raise DataFlowTerminated()
-        except zmq.ZMQError as e:
-            if e.errno == errno.ENOTSOCK:       # socket closed
-                logger.info("[Prefetch Master] Socket closed.")
-                raise DataFlowTerminated()
-            else:
-                raise
-        except:
-            raise
+        with self._guard:
+            try:
+                for k in itertools.count():
+                    if self._size > 0 and k >= self._size:
+                        break
+                    dp = loads(self.socket.recv(copy=False).bytes)
+                    yield dp
+            except zmq.ContextTerminated:
+                logger.info("[Prefetch Master] Context terminated.")
+                raise DataFlowTerminated()
+            except zmq.ZMQError as e:
+                if e.errno == errno.ENOTSOCK:       # socket closed
+                    logger.info("[Prefetch Master] Socket closed.")
+                    raise DataFlowTerminated()
+                else:
+                    raise
+            except:
+                raise

     def reset_state(self):
         """
@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
             t.start()

         self._iter = self._iter_ds.get_data()
+        self._guard = DataFlowReentrantGuard()

         # only call once, to ensure inq+outq has a total of buffer_size elements
         self._fill_buffer()
@@ -332,23 +338,24 @@ class ThreadedMapData(ProxyDataFlow):
             raise

     def get_data(self):
-        for dp in self._iter:
-            self._in_queue.put(dp)
-            yield self._out_queue.get()
-
-        self._iter = self._iter_ds.get_data()
-        if self._strict:
-            # first call get() to clear the queues, then fill
-            for k in range(self.buffer_size):
-                dp = self._out_queue.get()
-                if k == self.buffer_size - 1:
-                    self._fill_buffer()
-                yield dp
-        else:
-            for _ in range(self.buffer_size):
-                self._in_queue.put(next(self._iter))
-                yield self._out_queue.get()
+        with self._guard:
+            for dp in self._iter:
+                self._in_queue.put(dp)
+                yield self._out_queue.get()
+
+            self._iter = self._iter_ds.get_data()
+            if self._strict:
+                # first call get() to clear the queues, then fill
+                for k in range(self.buffer_size):
+                    dp = self._out_queue.get()
+                    if k == self.buffer_size - 1:
+                        self._fill_buffer()
+                    yield dp
+            else:
+                for _ in range(self.buffer_size):
+                    self._in_queue.put(next(self._iter))
+                    yield self._out_queue.get()

     def __del__(self):
         for p in self._threads:
             p.stop()
......
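To see the effect of the guard on one of the patched classes, a hedged sketch follows: reusing the same PrefetchData instance from a second thread now raises threading.ThreadError rather than letting two consumers steal datapoints from each other off the shared queue. The FakeData shapes and the nr_prefetch/nr_proc argument names are assumptions based on tensorpack around this revision and may differ in other releases.

# Hedged sketch (not part of the commit): constructor arguments follow
# tensorpack around this revision; fork-based multiprocessing (Linux) assumed.
import threading
from tensorpack.dataflow import FakeData, PrefetchData

ds = FakeData([[2, 2]], size=100)      # toy source yielding 2x2 arrays
ds = PrefetchData(ds, nr_prefetch=10, nr_proc=2)
ds.reset_state()

it = ds.get_data()
next(it)                               # main thread starts consuming


def bad_consumer():
    # With this commit, a second consumer on the same instance fails fast
    # instead of silently splitting datapoints from the shared queue.
    try:
        next(ds.get_data())
    except threading.ThreadError as e:
        print("guarded:", e)


t = threading.Thread(target=bad_consumer)
t.start()
t.join()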