Commit a027b8de authored by Yuxin Wu's avatar Yuxin Wu

try a different way of handling lmdb size

parent 13e3c39a
......@@ -144,7 +144,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 20, 5)
ds = PrefetchDataZMQ(ds, 5)
return ds
......
......@@ -146,16 +146,18 @@ class ILSVRC12CaffeLMDB(DataFlow):
self._meta = ILSVRCMeta()
self._shuffle = shuffle
self.rng = get_rng(self)
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
if shuffle:
with timed_operation("Loading LMDB keys ..."):
self.keys = [k for k, _ in self._lmdb.begin().cursor()]
self.keys = [k for k, _ in self._txn.cursor()]
def reset_state(self):
    """Reset per-process state: reseed the RNG and open a fresh LMDB read txn.

    Re-opening the transaction here means each process gets its own
    handle (the one created in ``__init__`` is not shared after a fork).
    """
    self.rng = get_rng(self)
    self._txn = self._lmdb.begin()
def size(self):
    """Return the number of entries in the LMDB database.

    Uses the entry count cached at construction time (``self._size``,
    taken from ``txn.stat()['entries']``) instead of re-querying the
    transaction on every call. The stale pre-refactor
    ``return self._txn.stat()['entries']`` line left over from the diff
    was unreachable/duplicated and has been removed.
    """
    return self._size
def get_data(self):
import imp
......
......@@ -48,14 +48,14 @@ class LinearWrap(object):
def __getattr__(self, layer_name):
    """Dynamically resolve *layer_name* to either a tensorpack layer or a module.

    The diff residue had duplicated old/new lines for the comment and the
    warning string; this is the resolved (post-commit) version.

    NOTE(review): ``eval`` on the attribute name executes arbitrary
    expressions if callers pass untrusted names — acceptable only because
    attribute names come from source code, but worth flagging.
    """
    layer = eval(layer_name)
    if hasattr(layer, 'f'):
        # this is a registered tensorpack layer
        def f(name, *args, **kwargs):
            # Chain the wrapped tensor through the layer, keeping the
            # fluent LinearWrap interface.
            ret = layer(name, self._t, *args, **kwargs)
            return LinearWrap(ret)
        return f
    else:
        if layer_name != 'tf':
            logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'!")
        assert isinstance(layer, ModuleType)
        return LinearWrap.TFModuleFunc(layer, self._t)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment