Commit 836e2019 authored by Yuxin Wu's avatar Yuxin Wu

fix LMDB transaction

parent 893c89dd
...@@ -43,6 +43,8 @@ with argscope(Conv2D, out_channel=32, kernel_shape=3, nl=tf.nn.relu): ...@@ -43,6 +43,8 @@ with argscope(Conv2D, out_channel=32, kernel_shape=3, nl=tf.nn.relu):
.Conv2D('conv2', kernel_shape=5) .Conv2D('conv2', kernel_shape=5)
.FullyConnected('fc0', 512, nl=tf.nn.relu) .FullyConnected('fc0', 512, nl=tf.nn.relu)
.Dropout('dropout', 0.5) .Dropout('dropout', 0.5)
.tf.multiply(0.5)
.apply(func, *args, **kwargs)
.FullyConnected('fc1', out_dim=10, nl=tf.identity)()) .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
``` ```
is equivalent to: is equivalent to:
...@@ -53,6 +55,8 @@ l = Conv2D('conv1', l, 32, 3, padding='SAME', nl=tf.nn.relu) ...@@ -53,6 +55,8 @@ l = Conv2D('conv1', l, 32, 3, padding='SAME', nl=tf.nn.relu)
l = Conv2D('conv2', l, 32, 5, nl=tf.nn.relu) l = Conv2D('conv2', l, 32, 5, nl=tf.nn.relu)
l = FullyConnected('fc0', l, 512, nl=tf.nn.relu) l = FullyConnected('fc0', l, 512, nl=tf.nn.relu)
l = Dropout('dropout', l, 0.5) l = Dropout('dropout', l, 0.5)
l = tf.multiply(l, 0.5)
l = func(l, *args, **kwargs)
l = FullyConnected('fc1', l, 10, nl=tf.identity) l = FullyConnected('fc1', l, 10, nl=tf.identity)
``` ```
......
...@@ -69,18 +69,18 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000): ...@@ -69,18 +69,18 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
sz = 0 sz = 0
with get_tqdm(total=sz) as pbar: with get_tqdm(total=sz) as pbar:
idx = -1 idx = -1
itr = ds.get_data()
# lmdb transaction is not exception-safe!
try: # although it has a contextmanager interface
while True: txn = db.begin(write=True)
with db.begin(write=True) as txn: for idx, dp in enumerate(ds.get_data()):
for _ in range(write_frequency): txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
idx += 1 pbar.update()
dp = next(itr) if (idx + 1) % write_frequency == 0:
txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp)) txn.commit()
pbar.update() txn = db.begin(write=True)
except StopIteration: txn.commit()
pass
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn: with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps(keys)) txn.put(b'__keys__', dumps(keys))
......
...@@ -30,6 +30,8 @@ def replace_get_variable(fn): ...@@ -30,6 +30,8 @@ def replace_get_variable(fn):
old_vars_getv = variable_scope.get_variable old_vars_getv = variable_scope.get_variable
tf.get_variable = fn tf.get_variable = fn
# doesn't seem to be working?
# and when it works, remap might call fn twice
variable_scope.get_variable = fn variable_scope.get_variable = fn
yield yield
tf.get_variable = old_getv tf.get_variable = old_getv
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment