Commit 8dbb84e2 authored by Yuxin Wu's avatar Yuxin Wu

Allow LMDBData to take DataFlow

parent 661abb69
...@@ -157,17 +157,32 @@ class LMDBDataDecoder(MapData): ...@@ -157,17 +157,32 @@ class LMDBDataDecoder(MapData):
class LMDBDataPoint(MapData): class LMDBDataPoint(MapData):
""" Read a LMDB file and produce deserialized datapoints. """
It reads the database produced by Read a LMDB file and produce deserialized datapoints.
:func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. It reads the database produced by
""" :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`.
Example:
.. code-block:: python
ds = LMDBDataPoint("/data/ImageNet.lmdb", shuffle=False)
# alternatively:
ds = LMDBData("/data/ImageNet.lmdb", shuffle=False)
ds = LocallyShuffleData(ds, 50000)
ds = LMDBDataPoint(ds)
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
Args: Args:
args, kwargs: Same as in :class:`LMDBData`. args, kwargs: Same as in :class:`LMDBData`.
""" """
ds = LMDBData(*args, **kwargs)
if isinstance(args[0], DataFlow):
ds = args[0]
else:
ds = LMDBData(*args, **kwargs)
def f(dp): def f(dp):
return loads(dp[1]) return loads(dp[1])
......
...@@ -464,7 +464,7 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -464,7 +464,7 @@ class StagingInputWrapper(FeedfreeInput):
self.get_stage_op(), self.get_unstage_op(), self._nr_stage)) self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
def setup_staging_areas(self): def setup_staging_areas(self):
logger.info("Setting up the StageAreas for GPU prefetching ...") logger.info("Setting up StagingArea for GPU prefetching ...")
for idx, device in enumerate(self._devices): for idx, device in enumerate(self._devices):
with tf.device(device): with tf.device(device):
inputs = self._input.get_input_tensors() inputs = self._input.get_input_tensors()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment