Commit 836e2019 authored by Yuxin Wu's avatar Yuxin Wu

fix LMDB transaction

parent 893c89dd
......@@ -43,6 +43,8 @@ with argscope(Conv2D, out_channel=32, kernel_shape=3, nl=tf.nn.relu):
.Conv2D('conv2', kernel_shape=5)
.FullyConnected('fc0', 512, nl=tf.nn.relu)
.Dropout('dropout', 0.5)
.tf.multiply(0.5)
.apply(func, *args, **kwargs)
.FullyConnected('fc1', out_dim=10, nl=tf.identity)())
```
is equivalent to:
......@@ -53,6 +55,8 @@ l = Conv2D('conv1', l, 32, 3, padding='SAME', nl=tf.nn.relu)
l = Conv2D('conv2', l, 32, 5, nl=tf.nn.relu)
l = FullyConnected('fc0', l, 512, nl=tf.nn.relu)
l = Dropout('dropout', l, 0.5)
l = tf.multiply(l, 0.5)
l = func(l, *args, **kwargs)
l = FullyConnected('fc1', l, 10, nl=tf.identity)
```
......
......@@ -69,18 +69,18 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
sz = 0
with get_tqdm(total=sz) as pbar:
idx = -1
itr = ds.get_data()
try:
while True:
with db.begin(write=True) as txn:
for _ in range(write_frequency):
idx += 1
dp = next(itr)
txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
except StopIteration:
pass
# lmdb transaction is not exception-safe!
# although it has a contextmanager interface
txn = db.begin(write=True)
for idx, dp in enumerate(ds.get_data()):
txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
if (idx + 1) % write_frequency == 0:
txn.commit()
txn = db.begin(write=True)
txn.commit()
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps(keys))
......
......@@ -30,6 +30,8 @@ def replace_get_variable(fn):
old_vars_getv = variable_scope.get_variable
tf.get_variable = fn
# doesn't seem to be working?
# and when it works, remap might call fn twice
variable_scope.get_variable = fn
yield
tf.get_variable = old_getv
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment