Commit 9b318943 authored by Yuxin Wu's avatar Yuxin Wu

grow lmdb map_size (fix #1209)

parent 413059b1
......@@ -20,7 +20,8 @@ KL = keras.layers
This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack.
This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
Note: this example does not work for replicated-style data-parallel trainers.
Note: this example does not work for replicated-style data-parallel trainers, so may be less efficient
for some models.
"""
IMAGE_SIZE = 28
......
......@@ -3,6 +3,7 @@
import numpy as np
import os
import platform
from collections import defaultdict
from ..utils import logger
......@@ -47,10 +48,31 @@ class LMDBSerializer():
assert not os.path.isfile(os.path.join(path, 'data.mdb')), "LMDB file exists!"
else:
assert not os.path.isfile(path), "LMDB file {} exists!".format(path)
# It's OK to use super large map_size on Linux, but not on other platforms
# See: https://github.com/NVIDIA/DIGITS/issues/206
map_size = 1099511627776 * 2 if platform.system() == 'Linux' else 128 * 10**6
db = lmdb.open(path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
map_size=map_size, readonly=False,
meminit=False, map_async=True) # need sync() at the end
size = _reset_df_and_get_size(df)
# put data into lmdb, and doubling the size if full.
# Ref: https://github.com/NVIDIA/DIGITS/pull/209/files
def put_or_grow(txn, key, value):
try:
txn.put(key, value)
return txn
except lmdb.MapFullError:
pass
txn.abort()
curr_size = db.info()['map_size']
new_size = curr_size * 2
logger.info("Doubling LMDB map_size to {:.2f}GB".format(new_size / 10**9))
db.set_mapsize(new_size)
txn = db.begin(write=True)
txn = put_or_grow(txn, key, value)
return txn
with get_tqdm(total=size) as pbar:
idx = -1
......@@ -58,7 +80,7 @@ class LMDBSerializer():
# although it has a context manager interface
txn = db.begin(write=True)
for idx, dp in enumerate(df):
txn.put(u'{:08}'.format(idx).encode('ascii'), dumps(dp))
txn = put_or_grow(txn, u'{:08}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
if (idx + 1) % write_frequency == 0:
txn.commit()
......@@ -67,7 +89,7 @@ class LMDBSerializer():
keys = [u'{:08}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps(keys))
txn = put_or_grow(txn, b'__keys__', dumps(keys))
logger.info("Flushing database ...")
db.sync()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment