Commit f6313a07 authored by Yuxin Wu

Move dataflow reentrant guard to reset_state() since it's not pickleable

parent e25cbf1a
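Background for the change: DataFlowReentrantGuard is backed by a threading lock, and lock objects cannot be pickled, so a dataflow that creates the guard in __init__ cannot be serialized (e.g. to a worker process) before reset_state() is called. Below is a minimal, illustrative sketch of the failure and of the pattern this commit adopts; it uses only the standard library, and the class names are hypothetical, not tensorpack code.

    import pickle
    import threading

    class EagerGuard(object):
        """Illustrative only: creates the (unpicklable) lock in __init__."""
        def __init__(self):
            self._guard = threading.Lock()

    class LazyGuard(object):
        """Illustrative only: mirrors this commit -- the lock is created in reset_state()."""
        def __init__(self):
            self._guard = None

        def reset_state(self):
            # Runs inside the worker process, after any pickling has happened.
            self._guard = threading.Lock()

    try:
        pickle.dumps(EagerGuard())
    except TypeError as err:
        print("cannot pickle:", err)   # lock objects are not pickleable

    pickle.dumps(LazyGuard())          # succeeds: no lock exists yet

Since reset_state() is the per-process setup hook for a dataflow, moving the guard there keeps the object pickleable up to the point where each worker initializes itself.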
@@ -193,11 +193,11 @@ class BatchDataByShape(BatchData):
         """
         super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
         self.idx = idx
-        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
         super(BatchDataByShape, self).reset_state()
         self.holder = defaultdict(list)
+        self._guard = DataFlowReentrantGuard()

     def __iter__(self):
         with self._guard:
@@ -235,7 +235,6 @@ class FixedSizeData(ProxyDataFlow):
         super(FixedSizeData, self).__init__(ds)
         self._size = int(size)
         self.itr = None
-        self._guard = DataFlowReentrantGuard()
         self._keep = keep_state

     def __len__(self):
@@ -244,6 +243,7 @@ class FixedSizeData(ProxyDataFlow):
     def reset_state(self):
         super(FixedSizeData, self).reset_state()
         self.itr = self.ds.__iter__()
+        self._guard = DataFlowReentrantGuard()

     def __iter__(self):
         with self._guard:
@@ -625,9 +625,9 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
         self.shuffle_interval = shuffle_interval
         self.nr_reuse = nr_reuse
         self._inf_ds = RepeatedData(ds, -1)
-        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
+        self._guard = DataFlowReentrantGuard()
         ProxyDataFlow.reset_state(self)
         RNGDataFlow.reset_state(self)
         self._iter_cnt = 0
@@ -664,11 +664,11 @@ class CacheData(ProxyDataFlow):
             shuffle (bool): whether to shuffle the datapoints before producing them.
         """
         self.shuffle = shuffle
-        self._guard = DataFlowReentrantGuard()
         super(CacheData, self).__init__(ds)

     def reset_state(self):
         super(CacheData, self).reset_state()
+        self._guard = DataFlowReentrantGuard()
         if self.shuffle:
             self.rng = get_rng(self)
         self.buffer = []
...
@@ -90,7 +90,6 @@ class LMDBData(RNGDataFlow):
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
-        self._guard = DataFlowReentrantGuard()

     def _set_keys(self, keys=None):
         def find_keys(txn, size):
@@ -128,6 +127,7 @@ class LMDBData(RNGDataFlow):
         self._txn = self._lmdb.begin()

     def reset_state(self):
+        self._guard = DataFlowReentrantGuard()
         self._lmdb.close()
         super(LMDBData, self).reset_state()
         self._open_lmdb()
...
@@ -306,7 +306,6 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
         self.nr_proc = nr_proc
         self._hwm = hwm

-        self._guard = DataFlowReentrantGuard()
         if nr_proc > 1:
             logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
                         "This assumes the datapoints are i.i.d.")
@@ -330,6 +329,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):

     def reset_state(self):
         super(PrefetchDataZMQ, self).reset_state()
+        self._guard = DataFlowReentrantGuard()
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
         self.socket.set_hwm(self._hwm)
...
@@ -279,11 +279,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
         self.map_func = map_func
         self._strict = strict
         self._procs = []
-        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
         _MultiProcessZMQDataFlow.reset_state(self)
         _ParallelMapData.reset_state(self)
+        self._guard = DataFlowReentrantGuard()

         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.DEALER)
@@ -369,7 +369,6 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
             processes=nr_proc,
             initializer=_init_pool,
             initargs=(self._shared_mem, id_queue, map_func))
-        self._guard = DataFlowReentrantGuard()

     def _create_shared_arr(self):
         TYPE = {
@@ -388,6 +387,7 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):

     def reset_state(self):
         self.ds.reset_state()
+        self._guard = DataFlowReentrantGuard()

     def __iter__(self):
         ds_itr = _repeat_iter(self.ds.get_data)
...
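For context, a hedged end-to-end sketch of the pattern the diff converges on: the object is sent to a child process while it holds no lock, and reset_state() creates the per-process guard. Standard library only; ToyDataFlow and worker are illustrative names, not tensorpack APIs.

    import multiprocessing as mp
    import threading

    class ToyDataFlow(object):
        """A stand-in dataflow; names and behavior are illustrative."""
        def __init__(self, data):
            self.data = list(data)
            self._guard = None                  # nothing unpicklable is created here

        def reset_state(self):
            # Per-process setup: the guard only comes into existence inside the worker.
            self._guard = threading.Lock()

        def __iter__(self):
            with self._guard:                   # serializes iteration within this process
                for dp in self.data:
                    yield dp

    def worker(df):
        df.reset_state()                        # create the guard after unpickling
        print(list(df))

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")           # spawn pickles the Process arguments
        df = ToyDataFlow(range(5))
        p = ctx.Process(target=worker, args=(df,))
        p.start()
        p.join()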