Commit a761a839 authored by Yuxin Wu

some more improvements in data reading

parent da5e9e66
@@ -111,7 +111,7 @@ class LMDBData(RNGDataFlow):
     def open_lmdb(self):
         self._lmdb = lmdb.open(self._lmdb_path,
                                subdir=os.path.isdir(self._lmdb_path),
-                               readonly=True, lock=False, readahead=False,
+                               readonly=True, lock=False, readahead=True,
                                map_size=1099511627776 * 2, max_readers=100)
         self._txn = self._lmdb.begin()
         self._size = self._txn.stat()['entries']
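For context, a minimal sketch of what the flag changes (assuming the `lmdb` package; the path and key handling below are made up). With readahead=True the OS prefetches pages ahead of the cursor, which suits the sequential scans LMDBData performs; readahead=False is normally only preferable for purely random access:

    import lmdb

    # Hypothetical database path; mirrors the open() arguments in the hunk above.
    db = lmdb.open('/path/to/dataset.lmdb', subdir=False,
                   readonly=True, lock=False, readahead=True,
                   map_size=1099511627776 * 2, max_readers=100)
    with db.begin() as txn:
        with txn.cursor() as cur:
            for key, value in cur:   # sequential scan benefits from OS readahead
                pass                 # deserialize `value` here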
@@ -70,8 +70,8 @@ class AugmentImageComponent(MapDataComponent):
                 raise
             except Exception:
                 self._nr_error += 1
-                if self._nr_error % 1000 == 0:
-                    logger.warn("Got {} augmentation errors.".format(self._nr_error))
+                if self._nr_error % 1000 == 0 or self._nr_error < 10:
+                    logger.exception("Got {} augmentation errors.".format(self._nr_error))
                 return None
             return ret
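The new condition reports the first few failures with a full traceback (logger.exception) and afterwards only every 1000th, so a rare bad sample is never silently dropped while a broken augmentor cannot flood the log. A standalone sketch of the same pattern (the counter, logger name and augment callable are hypothetical):

    import logging

    logger = logging.getLogger("augment")
    _nr_error = 0

    def safe_augment(augment, img):
        """Apply `augment` to `img`; return None instead of raising on failure."""
        global _nr_error
        try:
            return augment(img)
        except KeyboardInterrupt:
            raise
        except Exception:
            _nr_error += 1
            # First few errors: always log with traceback. After that: every 1000th.
            if _nr_error % 1000 == 0 or _nr_error < 10:
                logger.exception("Got {} augmentation errors.".format(_nr_error))
            return None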
@@ -111,8 +111,8 @@ class AugmentImageComponents(MapData):
                 raise
             except Exception:
                 self._nr_error += 1
-                if self._nr_error % 1000 == 0:
-                    logger.warn("Got {} augmentation errors.".format(self._nr_error))
+                if self._nr_error % 1000 == 0 or self._nr_error < 10:
+                    logger.exception("Got {} augmentation errors.".format(self._nr_error))
                 return None
         super(AugmentImageComponents, self).__init__(ds, func)
@@ -19,7 +19,8 @@ from ..utils.serialize import loads, dumps
 from ..utils import logger
 from ..utils.gpu import change_gpu

-__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs']
+__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
+           'ThreadedMapData', 'StartNewProcess']

 class PrefetchProcess(mp.Process):
@@ -68,8 +69,7 @@ class PrefetchData(ProxyDataFlow):
         self.procs = [PrefetchProcess(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
         ensure_proc_terminate(self.procs)
-        for x in self.procs:
-            x.start()
+        start_proc_mask_signal(self.procs)

     def get_data(self):
         for k in itertools.count():
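start_proc_mask_signal starts the workers while SIGINT is ignored, so a Ctrl-C is handled once by the parent instead of by every child. A hedged sketch of that idea with plain multiprocessing (the helper names here are stand-ins, not tensorpack's API):

    import multiprocessing as mp
    import signal
    from contextlib import contextmanager

    @contextmanager
    def sigint_ignored():
        # Children forked inside this block inherit SIG_IGN for SIGINT, so only
        # the parent reacts to Ctrl-C; the parent's handler is restored on exit.
        old = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, old)

    def start_workers_masking_sigint(procs):
        with sigint_ignored():
            for p in procs:
                p.start()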
@@ -105,6 +105,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
     """
     Prefetch data from a DataFlow using multiple processes, with ZMQ for
     communication.
+
+    Note that this dataflow is not fork-safe. You cannot nest this dataflow
+    into another PrefetchDataZMQ or PrefetchData.
     """
     def __init__(self, ds, nr_proc=1, pipedir=None, hwm=50):
         """
@@ -262,3 +265,21 @@ class ThreadedMapData(ProxyDataFlow):
             for _ in range(sz):
                 self._in_queue.put(next(self._itr))
             yield self._out_queue.get()
+
+
+def StartNewProcess(ds, queue_size):
+    """
+    Run ds in a new process, and use a multiprocessing.Queue to send data back.
+
+    Args:
+        ds (DataFlow): a DataFlow.
+        queue_size (int): the size of the queue.
+
+    Returns:
+        a fork-safe DataFlow, which is therefore safe to use under another
+        PrefetchData or PrefetchDataZMQ.
+
+    Note:
+        There could be a zmq version of this in the future.
+    """
+    return PrefetchData(ds, queue_size, 1)
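A hedged usage sketch of the new helper (raw_ds and the queue size are made up): StartNewProcess moves a dataflow that should not be forked directly into one dedicated process, and the returned wrapper can then be composed with other prefetchers, per the docstring:

    # `raw_ds` is a hypothetical DataFlow that must not be forked directly,
    # e.g. because it holds open file handles or is itself a PrefetchDataZMQ.
    ds = StartNewProcess(raw_ds, queue_size=100)   # raw_ds now runs in its own process
    ds = PrefetchDataZMQ(ds, nr_proc=1)            # safe per the docstring above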
@@ -147,7 +147,6 @@ class QueueInput(FeedfreeInput):
     def _get_input_tensors(self):
         ret = self.queue.dequeue(name='input_deque')
-        print(ret)
         if isinstance(ret, tf.Tensor):  # only one input
             ret = [ret]
         assert len(ret) == len(self.input_placehdrs)
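The isinstance check exists because of a TF1 queue API detail: dequeue() returns a single Tensor when the queue holds one component and a list when it holds several. A small sketch (TF1-style API, as tensorpack used at the time):

    import tensorflow as tf

    q1 = tf.FIFOQueue(8, [tf.float32])
    print(type(q1.dequeue()))      # single component -> a single tf.Tensor

    q2 = tf.FIFOQueue(8, [tf.float32, tf.int32])
    print(type(q2.dequeue()))      # multiple components -> a list of Tensors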