Commit a761a839 authored by Yuxin Wu

some more improvements in data reading

parent da5e9e66
...@@ -111,7 +111,7 @@ class LMDBData(RNGDataFlow): ...@@ -111,7 +111,7 @@ class LMDBData(RNGDataFlow):
def open_lmdb(self): def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path, self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path), subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False, readonly=True, lock=False, readahead=True,
map_size=1099511627776 * 2, max_readers=100) map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin() self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries'] self._size = self._txn.stat()['entries']
......
...@@ -70,8 +70,8 @@ class AugmentImageComponent(MapDataComponent): ...@@ -70,8 +70,8 @@ class AugmentImageComponent(MapDataComponent):
raise raise
except Exception: except Exception:
self._nr_error += 1 self._nr_error += 1
if self._nr_error % 1000 == 0: if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.warn("Got {} augmentation errors.".format(self._nr_error)) logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None return None
return ret return ret
...@@ -111,8 +111,8 @@ class AugmentImageComponents(MapData): ...@@ -111,8 +111,8 @@ class AugmentImageComponents(MapData):
raise raise
except Exception: except Exception:
self._nr_error += 1 self._nr_error += 1
if self._nr_error % 1000 == 0: if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.warn("Got {} augmentation errors.".format(self._nr_error)) logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None return None
super(AugmentImageComponents, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
......
...@@ -19,7 +19,8 @@ from ..utils.serialize import loads, dumps ...@@ -19,7 +19,8 @@ from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs'] __all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
'ThreadedMapData', 'StartNewProcess']
class PrefetchProcess(mp.Process): class PrefetchProcess(mp.Process):
...@@ -68,8 +69,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -68,8 +69,7 @@ class PrefetchData(ProxyDataFlow):
self.procs = [PrefetchProcess(self.ds, self.queue) self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
ensure_proc_terminate(self.procs) ensure_proc_terminate(self.procs)
for x in self.procs: start_proc_mask_signal(self.procs)
x.start()
def get_data(self): def get_data(self):
for k in itertools.count(): for k in itertools.count():
...@@ -105,6 +105,9 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -105,6 +105,9 @@ class PrefetchDataZMQ(ProxyDataFlow):
""" """
Prefetch data from a DataFlow using multiple processes, with ZMQ for Prefetch data from a DataFlow using multiple processes, with ZMQ for
communication. communication.
Note that this dataflow is not fork-safe. You cannot nest this dataflow
into another PrefetchDataZMQ or PrefetchData.
""" """
def __init__(self, ds, nr_proc=1, pipedir=None, hwm=50): def __init__(self, ds, nr_proc=1, pipedir=None, hwm=50):
""" """
...@@ -262,3 +265,21 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -262,3 +265,21 @@ class ThreadedMapData(ProxyDataFlow):
for _ in range(sz): for _ in range(sz):
self._in_queue.put(next(self._itr)) self._in_queue.put(next(self._itr))
yield self._out_queue.get() yield self._out_queue.get()
def StartNewProcess(ds, queue_size):
    """
    Run ``ds`` in a new process and use a multiprocessing queue to send
    the data back.

    Args:
        ds (DataFlow): a DataFlow.
        queue_size (int): the size of the queue.

    Returns:
        a fork-safe DataFlow, which is therefore safe to use under another
        PrefetchData or PrefetchDataZMQ.

    Note:
        There could be a zmq version of this in the future.
    """
    # The third positional argument is pinned to 1; presumably this is the
    # number of worker processes -- confirm against PrefetchData's signature.
    single = 1
    prefetched = PrefetchData(ds, queue_size, single)
    return prefetched
...@@ -147,7 +147,6 @@ class QueueInput(FeedfreeInput): ...@@ -147,7 +147,6 @@ class QueueInput(FeedfreeInput):
def _get_input_tensors(self): def _get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
print(ret)
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.input_placehdrs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment