Commit 8cdc6efd authored by Yuxin Wu's avatar Yuxin Wu

lmdb __keys__

parent 0a0101d0
......@@ -68,10 +68,10 @@ class ModelSaver(Callback):
linkname = os.path.join(os.path.dirname(latest), 'latest')
try:
os.unlink(linkname)
except FileNotFoundError:
except OSError:
pass
os.symlink(basename, linkname)
except Exception: # disk error sometimes.. just ignore
except OSError, IOError: # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback):
......
......@@ -8,6 +8,7 @@ from ..utils.loadcaffe import get_caffe_pb
from .base import DataFlow
import random
from tqdm import tqdm
from six.moves import range
try:
......@@ -69,8 +70,15 @@ class LMDBData(DataFlow):
self.rng = get_rng(self)
self._size = self._txn.stat()['entries']
if shuffle:
with timed_operation("Loading LMDB keys ...", log_start=True):
self.keys = [k for k, _ in self._txn.cursor()]
self.keys = self._txn.get('__keys__')
if not self.keys:
self.keys = []
with timed_operation("Loading LMDB keys ...", log_start=True), \
tqdm(total=self._size, ascii=True) as pbar:
for k in self._txn.cursor():
if k != '__keys__':
self.keys.append(k)
pbar.update()
def reset_state(self):
self._txn = self._lmdb.begin()
......@@ -84,7 +92,8 @@ class LMDBData(DataFlow):
c = self._txn.cursor()
while c.next():
k, v = c.item()
yield [k, v]
if k != '__keys__':
yield [k, v]
else:
s = self.size()
for i in range(s):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment