Commit d66d7761 authored by Yuxin Wu's avatar Yuxin Wu

better lmdb processing

parent 6aa8ab20
......@@ -15,19 +15,13 @@ from ..base import DataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12']
try:
import lmdb
except ImportError:
logger.warn("Error in 'import lmdb'. ILSVRC12CaffeLMDB won't be available.")
else:
__all__.append('ILSVRC12CaffeLMDB')
@memoized
def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
# TODO move caffe_pb outside
class ILSVRCMeta(object):
"""
Some metadata for ILSVRC dataset.
......@@ -159,60 +153,6 @@ class ILSVRC12(DataFlow):
im = np.expand_dims(im, 2).repeat(3,2)
yield [im, tp[1]]
# TODO more generally, just CaffeLMDB
class ILSVRC12CaffeLMDB(DataFlow):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
"""
:param shuffle: about 3 times slower
"""
self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._meta = ILSVRCMeta()
self._shuffle = shuffle
self.rng = get_rng(self)
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
if shuffle:
with timed_operation("Loading LMDB keys ...", log_start=True):
self.keys = [k for k, _ in self._txn.cursor()]
def reset_state(self):
self._txn = self._lmdb.begin()
self.rng = get_rng(self)
def size(self):
return self._size
def get_data(self):
import imp
cpb = imp.load_source('cpb', self._meta.caffe_pb_file)
datum = cpb.Datum()
def parse(k, v):
try:
datum.ParseFromString(v)
img = np.fromstring(datum.data, dtype=np.uint8)
img = img.reshape(datum.channels, datum.height, datum.width)
except Exception:
log_once("Cannot read key {}".format(k))
return None
return [img.transpose(1, 2, 0), datum.label]
if not self._shuffle:
c = self._txn.cursor()
while c.next():
k, v = c.item()
v = parse(k, v)
if v: yield v
else:
s = self.size()
for i in range(s):
k = self.rng.choice(self.keys)
v = self._txn.get(k)
v = parse(k, v)
if v: yield v
if __name__ == '__main__':
meta = ILSVRCMeta()
......@@ -221,8 +161,6 @@ if __name__ == '__main__':
#ds = ILSVRC12('/home/wyx/data/imagenet', 'val')
ds = ILSVRC12CaffeLMDB('/home/yuxinwu/', True)
for k in ds.get_data():
from IPython import embed; embed()
break
......@@ -16,6 +16,13 @@ except ImportError:
else:
__all__ = ['HDF5Data']
try:
import lmdb
except ImportError:
logger.warn("Error in 'import lmdb'. LMDBData won't be available.")
else:
__all__.extend(['LMDBData', 'CaffeLMDB'])
"""
Adapters for different data format.
......@@ -49,3 +56,64 @@ class HDF5Data(DataFlow):
for k in idxs:
yield [dp[k] for dp in self.dps]
class LMDBData(DataFlow):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_dir, shuffle=True):
self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._shuffle = shuffle
self.rng = get_rng(self)
self._size = self._txn.stat()['entries']
if shuffle:
with timed_operation("Loading LMDB keys ...", log_start=True):
self.keys = [k for k, _ in self._txn.cursor()]
def reset_state(self):
self._txn = self._lmdb.begin()
self.rng = get_rng(self)
def size(self):
return self._size
def get_data(self):
if not self._shuffle:
c = self._txn.cursor()
while c.next():
k, v = c.item()
yield [k, v]
else:
s = self.size()
for i in range(s):
k = self.rng.choice(self.keys)
v = self._txn.get(k)
yield [k, v]
class CaffeLMDB(LMDBData):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
"""
:param shuffle: about 3 times slower
"""
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
import imp
meta = ILSVRCMeta()
self.cpb = imp.load_source('cpb', meta.caffe_pb_file)
def get_data(self):
datum = self.cpb.Datum()
def parse(k, v):
try:
datum.ParseFromString(v)
img = np.fromstring(datum.data, dtype=np.uint8)
img = img.reshape(datum.channels, datum.height, datum.width)
except Exception:
log_once("Cannot read key {}".format(k))
return None
return [img.transpose(1, 2, 0), datum.label]
for dp in super(CaffeLMDB, self).get_data():
v = parse(dp[0], dp[1])
if v: yield v
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment