Commit d5143723 authored by Yuxin Wu's avatar Yuxin Wu

lmdb shouldn't reload keys

parent 00c47fa0
...@@ -77,10 +77,12 @@ class LMDBData(RNGDataFlow): ...@@ -77,10 +77,12 @@ class LMDBData(RNGDataFlow):
""" """
self._lmdb_path = lmdb_path self._lmdb_path = lmdb_path
self._shuffle = shuffle self._shuffle = shuffle
self.keys = keys
self.open_lmdb(keys)
def open_lmdb(self, keys=None): self.open_lmdb()
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
self._set_keys(keys)
def _set_keys(self, keys=None):
def find_keys(txn, size): def find_keys(txn, size):
logger.warn("Traversing the database to find keys is slow. Your should specify the keys.") logger.warn("Traversing the database to find keys is slow. Your should specify the keys.")
keys = [] keys = []
...@@ -92,15 +94,6 @@ class LMDBData(RNGDataFlow): ...@@ -92,15 +94,6 @@ class LMDBData(RNGDataFlow):
pbar.update() pbar.update()
return keys return keys
self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
if self._shuffle: if self._shuffle:
if keys is None: if keys is None:
# get the list of keys either from __keys__ or by iterating # get the list of keys either from __keys__ or by iterating
...@@ -112,10 +105,21 @@ class LMDBData(RNGDataFlow): ...@@ -112,10 +105,21 @@ class LMDBData(RNGDataFlow):
# check if key-format like '{:0>8d}' was given # check if key-format like '{:0>8d}' was given
if isinstance(keys, six.string_types): if isinstance(keys, six.string_types):
self.keys = map(lambda x: keys.format(x), list(np.arange(self._size))) self.keys = map(lambda x: keys.format(x), list(np.arange(self._size)))
else:
self.keys = keys
def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
def reset_state(self): def reset_state(self):
self._lmdb.close()
super(LMDBData, self).reset_state() super(LMDBData, self).reset_state()
self.open_lmdb(self.keys) self.open_lmdb()
def size(self): def size(self):
return self._size return self._size
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment