Commit 0a0b387e authored by Yuxin Wu

Make LMDBData usable under spawn (fix #1219)

parent 4df295ca
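Under the "spawn" start method a dataflow is pickled in the parent process and rebuilt in each worker, so it must not carry an open LMDB environment or an unpicklable lambda across that boundary. This commit addresses both: LMDBData.__init__ now closes the environment once the keys have been read (it is reopened later in reset_state(), which runs in the worker), and the deserialization lambda in LMDBSerializer.load is replaced by a static method. The sketch below is illustrative only, showing the spawn-side pattern the commit enables; the path "data.lmdb" and the worker function are placeholders, not part of tensorpack.

# Illustrative sketch: iterate a serialized LMDB dataflow from a spawned worker.
# "data.lmdb" is a placeholder path (create it with LMDBSerializer.save first);
# the worker function is not part of tensorpack.
import multiprocessing as mp

from tensorpack.dataflow import LMDBSerializer


def worker(df):
    # reset_state() runs inside the child, so the LMDB environment is opened
    # in the worker process instead of being inherited from the parent.
    df.reset_state()
    for dp in df:
        pass  # consume datapoints


if __name__ == "__main__":
    df = LMDBSerializer.load("data.lmdb", shuffle=False)  # df is picklable now
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker, args=(df,))
    p.start()
    p.join()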
@@ -91,6 +91,11 @@ class LMDBData(RNGDataFlow):
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
+        # Clean them up after finding the list of keys, since we don't want to fork them
+        self._lmdb.close()
+        del self._lmdb
+        del self._txn
+
     def _set_keys(self, keys=None):
         def find_keys(txn, size):
             logger.warn("Traversing the database to find keys is slow. Your should specify the keys.")
@@ -128,9 +133,8 @@ class LMDBData(RNGDataFlow):
     def reset_state(self):
         self._guard = DataFlowReentrantGuard()
-        self._lmdb.close()
         super(LMDBData, self).reset_state()
-        self._open_lmdb()
+        self._open_lmdb()  # open the LMDB in the worker process

     def __len__(self):
         return self._size
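The LMDBData change follows a general pattern for making a reader safe to send to worker processes: open the database in the parent only to collect metadata (entry count, keys), close it immediately, and reopen it lazily in whichever process actually iterates. Below is a standalone sketch of that pattern using the lmdb package directly, assuming a single-file database; the class and method names are illustrative, not tensorpack's API.

# Standalone sketch of the close-then-reopen pattern with the `lmdb` package;
# the class name and constructor arguments are illustrative, not tensorpack's.
import lmdb


class SpawnSafeLMDBReader:
    def __init__(self, path):
        self._path = path
        # Open in the parent only to collect metadata ...
        env = lmdb.open(path, subdir=False, readonly=True, lock=False)
        with env.begin() as txn:
            self._keys = [k for k, _ in txn.cursor()]
        # ... then close it, so no open LMDB handle gets pickled or forked.
        env.close()
        self._env = None

    def reset_state(self):
        # Called in the worker process: (re)open the environment there.
        self._env = lmdb.open(self._path, subdir=False, readonly=True, lock=False)

    def __iter__(self):
        with self._env.begin() as txn:
            for k in self._keys:
                yield k, txn.get(k)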
@@ -41,6 +41,7 @@ class LMDBSerializer():
             df (DataFlow): the DataFlow to serialize.
             path (str): output path. Either a directory or an lmdb file.
             write_frequency (int): the frequency to write back data to disk.
+                A smaller value reduces memory usage.
         """
         assert isinstance(df, DataFlow), type(df)
         isdir = os.path.isdir(path)
@@ -103,7 +104,11 @@ class LMDBSerializer():
             and run deserialization as a mapper in parallel.
         """
         df = LMDBData(path, shuffle=shuffle)
-        return MapData(df, lambda dp: loads(dp[1]))
+        return MapData(df, LMDBSerializer._deserialize_lmdb)
+
+    @staticmethod
+    def _deserialize_lmdb(dp):
+        return loads(dp[1])

 class NumpySerializer():
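The LMDBSerializer change is driven by the same pickling requirement: under spawn, the mapper function stored in MapData travels to the worker via pickle, and pickle resolves functions by qualified name, which fails for a lambda but works for a function defined at class scope. A small self-contained demonstration follows; the Loader class and deserialize function are made-up names for illustration.

# Demonstrates why the lambda had to go; Loader/deserialize are made-up names.
import pickle


class Loader:
    @staticmethod
    def deserialize(dp):
        return dp[1]


try:
    pickle.dumps(lambda dp: dp[1])  # lambdas cannot be pickled by reference
except pickle.PicklingError as e:
    print("lambda:", e)

# A function defined at class scope pickles and unpickles by qualified name.
print(pickle.loads(pickle.dumps(Loader.deserialize))([0, "ok"]))  # -> ok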