Commit 79b9d0eb authored by Yuxin Wu's avatar Yuxin Wu

Guard some stateful dataflow with non-reentrancy

parent 6c905896
......@@ -3,7 +3,7 @@
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading
from abc import abstractmethod, ABCMeta
import six
from ..utils.utils import get_rng
......@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
pass
class DataFlowReentrantGuard(object):
"""
A tool to enforce thread-level non-reentrancy on DataFlow.
"""
def __init__(self):
self._lock = threading.Lock()
def __enter__(self):
self._succ = self._lock.acquire(blocking=False)
if not self._succ:
raise threading.ThreadError("This DataFlow cannot be reused under different threads!")
def __exit__(self, exc_type, exc_val, exc_tb):
self._lock.release()
return False
@six.add_metaclass(ABCMeta)
class DataFlow(object):
""" Base class for all DataFlow """
......
......@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
self.imglist = meta.get_image_list(name, dir_structure)
for fname, _ in self.imglist[:10]:
fname = os.path.join(self.full_dir, fname)
assert os.path.isfile(fname), fname
def size(self):
......
......@@ -11,7 +11,7 @@ import uuid
import os
import zmq
from .base import ProxyDataFlow, DataFlowTerminated
from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
StoppableThread)
......@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
self._size = -1
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
self._guard = DataFlowReentrantGuard()
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)]
......@@ -81,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
start_proc_mask_signal(self.procs)
def get_data(self):
with self._guard:
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
......@@ -155,9 +158,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.nr_proc = nr_proc
self._hwm = hwm
self._guard = DataFlowReentrantGuard()
self._setup_done = False
def get_data(self):
with self._guard:
try:
for k in itertools.count():
if self._size > 0 and k >= self._size:
......@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
t.start()
self._iter = self._iter_ds.get_data()
self._guard = DataFlowReentrantGuard()
# only call once, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer()
......@@ -332,6 +338,7 @@ class ThreadedMapData(ProxyDataFlow):
raise
def get_data(self):
with self._guard:
for dp in self._iter:
self._in_queue.put(dp)
yield self._out_queue.get()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment