Commit 79b9d0eb authored by Yuxin Wu

Guard some stateful dataflow with non-reentrancy

parent 6c905896
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import six import six
from ..utils.utils import get_rng from ..utils.utils import get_rng
...@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException): ...@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
pass pass
class DataFlowReentrantGuard(object):
    """
    A tool to enforce thread-level non-reentrancy on DataFlow.

    Use it as a context manager around a stateful section of code.
    Entering the guard while it is already held raises immediately instead
    of blocking. The underlying ``threading.Lock`` is not reentrant, so a
    nested entry from the same thread is rejected too.

    Raises:
        threading.ThreadError: on ``__enter__`` if the guard is already held.
    """

    def __init__(self):
        self._lock = threading.Lock()

    def __enter__(self):
        # Pass the non-blocking flag positionally: Python 2's Lock.acquire
        # rejects keyword arguments, while Python 3 accepts both forms.
        self._succ = self._lock.acquire(False)
        if not self._succ:
            raise threading.ThreadError("This DataFlow cannot be reused under different threads!")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # __exit__ only runs after a successful __enter__, so the lock is held.
        self._lock.release()
        # Returning False lets any exception from the guarded body propagate.
        return False
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """ """ Base class for all DataFlow """
......
...@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
self.imglist = meta.get_image_list(name, dir_structure) self.imglist = meta.get_image_list(name, dir_structure)
for fname, _ in self.imglist[:10]: for fname, _ in self.imglist[:10]:
fname = os.path.join(self.full_dir, fname)
assert os.path.isfile(fname), fname assert os.path.isfile(fname), fname
def size(self): def size(self):
......
...@@ -11,7 +11,7 @@ import uuid ...@@ -11,7 +11,7 @@ import uuid
import os import os
import zmq import zmq
from .base import ProxyDataFlow, DataFlowTerminated from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate, from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal, mask_sigint, start_proc_mask_signal,
StoppableThread) StoppableThread)
...@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow): ...@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
self._size = -1 self._size = -1
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch self.nr_prefetch = nr_prefetch
self._guard = DataFlowReentrantGuard()
self.queue = mp.Queue(self.nr_prefetch) self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue) self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
...@@ -81,6 +83,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -81,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
start_proc_mask_signal(self.procs) start_proc_mask_signal(self.procs)
def get_data(self): def get_data(self):
with self._guard:
for k in itertools.count(): for k in itertools.count():
if self._size > 0 and k >= self._size: if self._size > 0 and k >= self._size:
break break
...@@ -155,9 +158,11 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -155,9 +158,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.nr_proc = nr_proc self.nr_proc = nr_proc
self._hwm = hwm self._hwm = hwm
self._guard = DataFlowReentrantGuard()
self._setup_done = False self._setup_done = False
def get_data(self): def get_data(self):
with self._guard:
try: try:
for k in itertools.count(): for k in itertools.count():
if self._size > 0 and k >= self._size: if self._size > 0 and k >= self._size:
...@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
t.start() t.start()
self._iter = self._iter_ds.get_data() self._iter = self._iter_ds.get_data()
self._guard = DataFlowReentrantGuard()
# only call once, to ensure inq+outq has a total of buffer_size elements # only call once, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer() self._fill_buffer()
...@@ -332,6 +338,7 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -332,6 +338,7 @@ class ThreadedMapData(ProxyDataFlow):
raise raise
def get_data(self): def get_data(self):
with self._guard:
for dp in self._iter: for dp in self._iter:
self._in_queue.put(dp) self._in_queue.put(dp)
yield self._out_queue.get() yield self._out_queue.get()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment