Commit a027b8de authored by Yuxin Wu

try a different way of handling lmdb size

parent 13e3c39a
@@ -144,7 +144,7 @@ def get_data(train_or_test):
     ds = AugmentImageComponent(ds, augmentors)
     ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
     if isTrain:
-        ds = PrefetchData(ds, 20, 5)
+        ds = PrefetchDataZMQ(ds, 5)
     return ds
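The hunk above swaps the multiprocessing-based PrefetchData (queue of 20, 5 workers) for PrefetchDataZMQ with 5 worker processes. As a rough sketch of what process-based prefetching does in general, and not tensorpack's implementation, the toy prefetcher below overlaps data production with consumption; SimplePrefetcher, _worker, _demo_source, and the queue size are made-up names for illustration:

import multiprocessing as mp

_SENTINEL = None

def _worker(get_data_fn, queue):
    # Runs in a child process: pull datapoints from the source and
    # push them into a bounded queue (put() blocks when it is full).
    for dp in get_data_fn():
        queue.put(dp)
    queue.put(_SENTINEL)  # tell the consumer we are done

class SimplePrefetcher(object):
    def __init__(self, get_data_fn, queue_size=20):
        self._queue = mp.Queue(queue_size)
        self._proc = mp.Process(target=_worker, args=(get_data_fn, self._queue))
        self._proc.daemon = True
        self._proc.start()

    def __iter__(self):
        while True:
            dp = self._queue.get()
            if dp is _SENTINEL:
                return
            yield dp

def _demo_source():
    for i in range(10):
        yield i

if __name__ == '__main__':
    print(list(SimplePrefetcher(_demo_source, queue_size=4)))  # [0, 1, ..., 9]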
@@ -146,16 +146,18 @@ class ILSVRC12CaffeLMDB(DataFlow):
         self._meta = ILSVRCMeta()
         self._shuffle = shuffle
         self.rng = get_rng(self)
+        self._txn = self._lmdb.begin()
+        self._size = self._txn.stat()['entries']
         if shuffle:
             with timed_operation("Loading LMDB keys ..."):
-                self.keys = [k for k, _ in self._lmdb.begin().cursor()]
+                self.keys = [k for k, _ in self._txn.cursor()]

     def reset_state(self):
         self._txn = self._lmdb.begin()
         self.rng = get_rng(self)

     def size(self):
-        return self._txn.stat()['entries']
+        return self._size

     def get_data(self):
         import imp
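The point of the hunk above: txn.stat() was re-evaluated on every size() call, so the commit opens one read transaction up front, caches its 'entries' count in self._size, and reuses the same transaction for the key scan. A minimal standalone sketch of the same pattern with py-lmdb (the class name and arguments here are illustrative, not tensorpack's API):

import lmdb

class CachedSizeLMDB(object):
    def __init__(self, lmdb_path, shuffle=True):
        # readonly + lock=False is a common setting for read-only dataflows
        self._lmdb = lmdb.open(lmdb_path, readonly=True, lock=False)
        self._txn = self._lmdb.begin()
        # query stat() once; size() then becomes a cheap attribute read
        self._size = self._txn.stat()['entries']
        if shuffle:
            # reuse the transaction opened above for the key scan
            self.keys = [k for k, _ in self._txn.cursor()]

    def size(self):
        return self._size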
@@ -48,14 +48,14 @@ class LinearWrap(object):
     def __getattr__(self, layer_name):
         layer = eval(layer_name)
         if hasattr(layer, 'f'):
-            # a registered tensorpack layer
+            # this is a registered tensorpack layer
             def f(name, *args, **kwargs):
                 ret = layer(name, self._t, *args, **kwargs)
                 return LinearWrap(ret)
             return f
         else:
             if layer_name != 'tf':
-                logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'. Not officially supported yet!")
+                logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'!")
             assert isinstance(layer, ModuleType)
             return LinearWrap.TFModuleFunc(layer, self._t)
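For context, LinearWrap chains layers by intercepting attribute access: a registered layer yields a curried function that feeds the wrapped tensor through the layer and rewraps the result, while module attributes (e.g. tf) fall through to TFModuleFunc. Below is a toy version of the registered-layer branch only, with a made-up registry and layer standing in for tensorpack's globals:

# Hypothetical registry standing in for tensorpack's registered layers.
def Conv2D(name, x, out_channel):
    return '%s(%s, %d)' % (name, x, out_channel)
Conv2D.f = Conv2D  # mimic the hasattr(layer, 'f') registration marker

_LAYERS = {'Conv2D': Conv2D}

class LinearWrapSketch(object):
    def __init__(self, tensor):
        self._t = tensor

    def __getattr__(self, layer_name):
        layer = _LAYERS.get(layer_name)
        if layer is not None and hasattr(layer, 'f'):
            # curry the wrapped tensor into the layer, rewrap the output
            def f(name, *args, **kwargs):
                return LinearWrapSketch(layer(name, self._t, *args, **kwargs))
            return f
        raise AttributeError(layer_name)

    def __call__(self):
        return self._t

print(LinearWrapSketch('input').Conv2D('conv0', 32)())  # conv0(input, 32)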