Commit 0a0b387e authored by Yuxin Wu

Make LMDBData usable under spawn (fix #1219)

parent 4df295ca
......@@ -91,6 +91,11 @@ class LMDBData(RNGDataFlow):
self._set_keys(keys)
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
# Clean them up after finding the list of keys, since we don't want to fork them
self._lmdb.close()
del self._lmdb
del self._txn
def _set_keys(self, keys=None):
def find_keys(txn, size):
logger.warn("Traversing the database to find keys is slow. Your should specify the keys.")
......@@ -128,9 +133,8 @@ class LMDBData(RNGDataFlow):
# Resets per-process state for this dataflow: creates a fresh reentrant guard,
# delegates to the parent class's reset_state, and opens the LMDB.
# NOTE(review): this span comes from a diff view whose +/- markers were lost.
# The hunk header (-128,9 +133,8) and the duplicated `_open_lmdb()` call
# suggest old and new lines are interleaved here — confirm against the commit.
def reset_state(self):
self._guard = DataFlowReentrantGuard()
# NOTE(review): presumably a line REMOVED by this commit; elsewhere in this
# diff `self._lmdb` is closed and deleted before this point — confirm.
self._lmdb.close()
super(LMDBData, self).reset_state()
# NOTE(review): the next two lines look like the old and new versions of the
# same statement; only one call is expected in the final file — confirm.
self._open_lmdb()
self._open_lmdb() # open the LMDB in the worker process
# Length of the dataflow: returns the value stored in `self._size`.
def __len__(self):
return self._size
......
......@@ -41,6 +41,7 @@ class LMDBSerializer():
df (DataFlow): the DataFlow to serialize.
path (str): output path. Either a directory or an lmdb file.
write_frequency (int): the frequency to write back data to disk.
A smaller value reduces memory usage.
"""
assert isinstance(df, DataFlow), type(df)
isdir = os.path.isdir(path)
......@@ -103,7 +104,11 @@ class LMDBSerializer():
and run deserialization as a mapper in parallel.
"""
df = LMDBData(path, shuffle=shuffle)
return MapData(df, lambda dp: loads(dp[1]))
return MapData(df, LMDBSerializer._deserialize_lmdb)
# Named, static equivalent of the inline mapper `lambda dp: loads(dp[1])`:
# takes a datapoint and returns `loads` applied to its second element.
# NOTE(review): presumably introduced so the mapper can be pickled when worker
# processes are started with "spawn" (lambdas cannot be pickled) — confirm.
@staticmethod
def _deserialize_lmdb(dp):
# dp[1] is the element handed to `loads`; dp[0] is not used here.
return loads(dp[1])
class NumpySerializer():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment