Commit 8dbb84e2 authored by Yuxin Wu's avatar Yuxin Wu

Allow LMDBData to take DataFlow

parent 661abb69
......@@ -157,9 +157,20 @@ class LMDBDataDecoder(MapData):
class LMDBDataPoint(MapData):
""" Read a LMDB file and produce deserialized datapoints.
"""
Read a LMDB file and produce deserialized datapoints.
It reads the database produced by
:func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`.
Example:
.. code-block:: python
ds = LMDBDataPoint("/data/ImageNet.lmdb", shuffle=False)
# alternatively:
ds = LMDBData("/data/ImageNet.lmdb", shuffle=False)
ds = LocallyShuffleData(ds, 50000)
ds = LMDBDataPoint(ds)
"""
def __init__(self, *args, **kwargs):
......@@ -167,6 +178,10 @@ class LMDBDataPoint(MapData):
Args:
args, kwargs: Same as in :class:`LMDBData`.
"""
if isinstance(args[0], DataFlow):
ds = args[0]
else:
ds = LMDBData(*args, **kwargs)
def f(dp):
......
......@@ -464,7 +464,7 @@ class StagingInputWrapper(FeedfreeInput):
self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
def setup_staging_areas(self):
logger.info("Setting up the StageAreas for GPU prefetching ...")
logger.info("Setting up StagingArea for GPU prefetching ...")
for idx, device in enumerate(self._devices):
with tf.device(device):
inputs = self._input.get_input_tensors()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment