Commit 7c7f6e85 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

LMDB fix (#132)

* LMDB fix

LMDB was slow (>35min for ILSVRC2012_img_train.lmdb) when iterating
over the entries to find the keys. Now, it is possible to specify
the key_format, such that there is no overhead anymore.

Changes:
- use key_format to generate keys
- remove lmdb_open in reset-state (was open db twice)
- msgpack raised an exception, which is now catched

* use six to check stringtype

* rename key_format argument in lmdb and proper documentation

* small changes for prefetching

* some docs
parent 1c4166f7
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np import numpy as np
import six
from six.moves import range from six.moves import range
import os import os
...@@ -59,38 +60,61 @@ class HDF5Data(RNGDataFlow): ...@@ -59,38 +60,61 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow): class LMDBData(RNGDataFlow):
""" Read a LMDB database and produce (k,v) pairs """ """ Read a LMDB database and produce (k,v) pairs """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True, keys=None):
""" """
Args: Args:
lmdb_path (str): a directory or a file. lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not. shuffle (bool): shuffle the keys or not.
keys (list of str or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. `'{:0>8d}'` which will be
formatted with the indices from 0 to `total_size - 1`.
If not provided, it will then look in the database for `__keys__` which
:func:`dump_dataflow_to_lmdb` used to store the list of keys.
If still not found, it will iterate over the database to find
all the keys.
""" """
self._lmdb_path = lmdb_path self._lmdb_path = lmdb_path
self._shuffle = shuffle self._shuffle = shuffle
self.open_lmdb() self.keys = keys
self.open_lmdb(keys)
def open_lmdb(self, keys=None):
def find_keys(txn, size):
logger.warn("Traversing the database to find keys is slow. Your should specify the keys.")
keys = []
with timed_operation("Loading LMDB keys ...", log_start=True), \
get_tqdm(total=size) as pbar:
for k in self._txn.cursor():
assert k[0] != '__keys__'
keys.append(k[0])
pbar.update()
return keys
def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path, self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path), subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False, readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100) map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin() self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries'] self._size = self._txn.stat()['entries']
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
if self._shuffle: if self._shuffle:
# get the list of keys either from __keys__ or by iterating if keys is None:
self.keys = loads(self._txn.get('__keys__')) # get the list of keys either from __keys__ or by iterating
if not self.keys: try:
self.keys = [] self.keys = loads(self._txn.get('__keys__'))
with timed_operation("Loading LMDB keys ...", log_start=True), \ except Exception:
get_tqdm(total=self._size) as pbar: self.keys = find_keys(self._txn, self._size)
for k in self._txn.cursor(): else:
if k[0] != '__keys__': # check if key-format like '{:0>8d}' was given
self.keys.append(k[0]) if isinstance(keys, six.string_types):
pbar.update() self.keys = map(lambda x: keys.format(x), list(np.arange(self._size)))
def reset_state(self): def reset_state(self):
super(LMDBData, self).reset_state() super(LMDBData, self).reset_state()
self.open_lmdb() self.open_lmdb(self.keys)
def size(self): def size(self):
return self._size return self._size
...@@ -111,15 +135,14 @@ class LMDBData(RNGDataFlow): ...@@ -111,15 +135,14 @@ class LMDBData(RNGDataFlow):
class LMDBDataDecoder(LMDBData): class LMDBDataDecoder(LMDBData):
""" Read a LMDB database and produce a decoded output.""" """ Read a LMDB database and produce a decoded output."""
def __init__(self, lmdb_path, decoder, shuffle=True): def __init__(self, lmdb_path, decoder, shuffle=True, keys=None):
""" """
Args: Args:
lmdb_path (str): a directory or a file. lmdb_path, shuffle, keys: same as :class:`LMDBData`.
decoder (k,v -> dp | None): a function taking k, v and returning a datapoint, decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
or return None to discard. or return None to discard.
shuffle (bool): shuffle the keys or not.
""" """
super(LMDBDataDecoder, self).__init__(lmdb_path, shuffle) super(LMDBDataDecoder, self).__init__(lmdb_path, shuffle=shuffle, keys=keys)
self.decoder = decoder self.decoder = decoder
def get_data(self): def get_data(self):
...@@ -133,27 +156,30 @@ class LMDBDataPoint(LMDBDataDecoder): ...@@ -133,27 +156,30 @@ class LMDBDataPoint(LMDBDataDecoder):
""" Read a LMDB file and produce deserialized values. """ Read a LMDB file and produce deserialized values.
This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """ This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True, keys=None):
""" """
Args: Args:
lmdb_path (str): a directory or a file. lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not. shuffle (bool): shuffle the keys or not.
keys (list): list of keys for lmdb file or the key format `'{:0>8d}'`
""" """
super(LMDBDataPoint, self).__init__( super(LMDBDataPoint, self).__init__(
lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle) lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle, keys=keys)
class CaffeLMDB(LMDBDataDecoder): class CaffeLMDB(LMDBDataDecoder):
""" """
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf. Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label]. Produces datapoints of the format: [HWC image, label].
Example:
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
""" """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True, keys=None):
""" """
Args: Args:
lmdb_path (str): a directory or a file. lmdb_path, shuffle, keys: same as :class:`LMDBData`.
shuffle (bool): shuffle the keys or not.
""" """
cpb = get_caffe_pb() cpb = get_caffe_pb()
...@@ -169,7 +195,7 @@ class CaffeLMDB(LMDBDataDecoder): ...@@ -169,7 +195,7 @@ class CaffeLMDB(LMDBDataDecoder):
return [img.transpose(1, 2, 0), datum.label] return [img.transpose(1, 2, 0), datum.label]
super(CaffeLMDB, self).__init__( super(CaffeLMDB, self).__init__(
lmdb_path, decoder=decoder, shuffle=shuffle) lmdb_path, decoder=decoder, shuffle=shuffle, keys=keys)
class SVMLightData(RNGDataFlow): class SVMLightData(RNGDataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment