Commit f6313a07 authored by Yuxin Wu's avatar Yuxin Wu

Move dataflow reentrant guard to reset_state() since it's not pickleable

parent e25cbf1a
......@@ -193,11 +193,11 @@ class BatchDataByShape(BatchData):
"""
super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
self.idx = idx
self._guard = DataFlowReentrantGuard()
def reset_state(self):
    """Reset the dataflow for iteration.

    The reentrant guard is (re)created here rather than in ``__init__``
    because ``DataFlowReentrantGuard`` is not pickleable, and the dataflow
    object may need to be pickled before ``reset_state()`` is called
    (e.g. when forked to worker processes).
    """
    super(BatchDataByShape, self).reset_state()
    # Fresh per-shape buckets for grouping datapoints of identical shape.
    self.holder = defaultdict(list)
    self._guard = DataFlowReentrantGuard()
def __iter__(self):
with self._guard:
......@@ -235,7 +235,6 @@ class FixedSizeData(ProxyDataFlow):
super(FixedSizeData, self).__init__(ds)
self._size = int(size)
self.itr = None
self._guard = DataFlowReentrantGuard()
self._keep = keep_state
def __len__(self):
......@@ -244,6 +243,7 @@ class FixedSizeData(ProxyDataFlow):
def reset_state(self):
    """Reset the dataflow and start a new iterator over the wrapped dataflow.

    The reentrant guard is created here instead of ``__init__`` because
    ``DataFlowReentrantGuard`` is not pickleable and the dataflow may be
    pickled (e.g. sent to a child process) before ``reset_state()`` runs.
    """
    super(FixedSizeData, self).reset_state()
    # Keep a persistent iterator so iteration state can survive across
    # epochs (used together with the ``keep_state`` option).
    self.itr = self.ds.__iter__()
    self._guard = DataFlowReentrantGuard()
def __iter__(self):
with self._guard:
......@@ -625,9 +625,9 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self.shuffle_interval = shuffle_interval
self.nr_reuse = nr_reuse
self._inf_ds = RepeatedData(ds, -1)
self._guard = DataFlowReentrantGuard()
def reset_state(self):
self._guard = DataFlowReentrantGuard()
ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self)
self._iter_cnt = 0
......@@ -664,11 +664,11 @@ class CacheData(ProxyDataFlow):
shuffle (bool): whether to shuffle the datapoints before producing them.
"""
self.shuffle = shuffle
self._guard = DataFlowReentrantGuard()
super(CacheData, self).__init__(ds)
def reset_state(self):
super(CacheData, self).reset_state()
self._guard = DataFlowReentrantGuard()
if self.shuffle:
self.rng = get_rng(self)
self.buffer = []
......
......@@ -90,7 +90,6 @@ class LMDBData(RNGDataFlow):
self._size = self._txn.stat()['entries']
self._set_keys(keys)
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
self._guard = DataFlowReentrantGuard()
def _set_keys(self, keys=None):
def find_keys(txn, size):
......@@ -128,6 +127,7 @@ class LMDBData(RNGDataFlow):
self._txn = self._lmdb.begin()
def reset_state(self):
self._guard = DataFlowReentrantGuard()
self._lmdb.close()
super(LMDBData, self).reset_state()
self._open_lmdb()
......
......@@ -306,7 +306,6 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
self.nr_proc = nr_proc
self._hwm = hwm
self._guard = DataFlowReentrantGuard()
if nr_proc > 1:
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
......@@ -330,6 +329,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
def reset_state(self):
super(PrefetchDataZMQ, self).reset_state()
self._guard = DataFlowReentrantGuard()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(self._hwm)
......
......@@ -279,11 +279,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self.map_func = map_func
self._strict = strict
self._procs = []
self._guard = DataFlowReentrantGuard()
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
_ParallelMapData.reset_state(self)
self._guard = DataFlowReentrantGuard()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.DEALER)
......@@ -369,7 +369,6 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
processes=nr_proc,
initializer=_init_pool,
initargs=(self._shared_mem, id_queue, map_func))
self._guard = DataFlowReentrantGuard()
def _create_shared_arr(self):
TYPE = {
......@@ -388,6 +387,7 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
def reset_state(self):
    """Reset the underlying dataflow.

    The reentrant guard is created here rather than in ``__init__``
    because ``DataFlowReentrantGuard`` is not pickleable, and this object
    may be pickled (e.g. for multiprocessing) before iteration starts.
    """
    self.ds.reset_state()
    self._guard = DataFlowReentrantGuard()
def __iter__(self):
ds_itr = _repeat_iter(self.ds.get_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment