Commit 12a7b7ff authored by Yuxin Wu

caffe lmdb dataflow

parent 8532e89d
......@@ -6,13 +6,25 @@ import os
import tarfile
import cv2
import numpy as np
from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir
from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.timer import timed_operation
from ...utils.fs import mkdir_p, download
from ..base import DataFlow
# Names exported unconditionally; the LMDB-backed dataflow is added below
# only when its dependency can be imported.
__all__ = ['ILSVRCMeta', 'ILSVRC12']

# `lmdb` is optional: a missing package only produces a warning, and
# ILSVRC12CaffeLMDB is then left out of the public names.
try:
    import lmdb
except ImportError:
    logger.warn("Error in 'import lmdb'. ILSVRC12CaffeLMDB won't be available.")
else:
    __all__ += ['ILSVRC12CaffeLMDB']
@memoized
def log_once(s):
    """
    Emit ``s`` as a warning through the module logger.

    The function is wrapped in ``memoized``, so a repeated call with the same
    message is expected to be served from the cache instead of logging again.
    """
    logger.warn(s)
# Remote locations of two Caffe resources: the `caffe_ilsvrc12.tar.gz` archive
# on the Caffe download server, and the `caffe.proto` protobuf schema on GitHub.
# NOTE(review): the code consuming these URLs is outside this view -- presumably
# ILSVRCMeta downloads them; confirm before changing either value.
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
......@@ -123,11 +135,67 @@ class ILSVRC12(DataFlow):
im = np.expand_dims(im, 2).repeat(3,2)
yield [im, tp[1]]
class ILSVRC12CaffeLMDB(DataFlow):
    """
    Read ``[image, label]`` datapoints from an LMDB database whose values are
    serialized Caffe ``Datum`` messages.
    """

    def __init__(self, lmdb_dir, shuffle=True):
        """
        :param lmdb_dir: path of the LMDB database; it is opened read-only
            and without locking.
        :param shuffle: if True, keys are drawn at random instead of walking
            the database in key order. About 3 times slower.
        """
        self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
                               map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()
        self._meta = ILSVRCMeta()
        self._shuffle = shuffle
        self.rng = get_rng(self)
        if shuffle:
            with timed_operation("Loading LMDB keys ..."):
                # Use a short-lived transaction that is closed as soon as the
                # keys are collected, and ask the cursor for keys only so the
                # values are not copied just to be thrown away.
                with self._lmdb.begin() as txn:
                    self.keys = list(
                        txn.cursor().iternext(keys=True, values=False))

    def reset_state(self):
        """Open a fresh read transaction and re-create the random generator."""
        self._txn = self._lmdb.begin()
        self.rng = get_rng(self)

    def size(self):
        """
        :returns: the ``'entries'`` count from the transaction's statistics.
        """
        return self._txn.stat()['entries']

    def get_data(self):
        """
        Yield ``[img, label]`` for the records in the database.

        ``img`` is the ``uint8`` datum payload reshaped to
        (channels, height, width) and then transposed to
        (height, width, channels); ``label`` is ``datum.label``.
        Records that cannot be parsed are reported once per key through
        ``log_once`` and skipped.

        Without shuffling, records are produced in cursor order. With
        shuffling, ``size()`` keys are drawn with ``rng.choice`` (i.e. with
        replacement), so one pass may repeat some records and miss others.
        """
        # NOTE: `imp` is deprecated in Python 3; it is kept here because the
        # module also targets Python 2 (it imports from `six`).
        import imp
        cpb = imp.load_source('cpb', self._meta.caffe_pb_file)
        # A single Datum message is reused for every record.
        datum = cpb.Datum()

        def parse(k, v):
            """Decode one raw LMDB value; return None if it is unreadable."""
            try:
                datum.ParseFromString(v)
                # np.fromstring is deprecated for binary input. frombuffer
                # gives a read-only view, so copy to keep the array writable
                # as it was before.
                img = np.frombuffer(datum.data, dtype=np.uint8).copy()
                img = img.reshape(datum.channels, datum.height, datum.width)
            except Exception:
                # Deliberately best-effort: a bad record must not stop the
                # stream, so it is logged once and dropped.
                log_once("Cannot read key {}".format(k))
                return None
            return [img.transpose(1, 2, 0), datum.label]

        # Do NOT wrap the shared transaction in `with`: leaving the block
        # (or closing this generator early) would commit/abort it and break
        # every later get_data()/size() call until reset_state() is invoked.
        txn = self._txn
        if not self._shuffle:
            cursor = txn.cursor()
            while cursor.next():
                k, v = cursor.item()
                dp = parse(k, v)
                if dp is not None:
                    yield dp
        else:
            for _ in range(self.size()):
                k = self.rng.choice(self.keys)
                dp = parse(k, txn.get(k))
                if dp is not None:
                    yield dp
if __name__ == '__main__':
    # Manual smoke test: print the per-pixel mean, then drop into an IPython
    # shell on the first datapoint read from a local LMDB database.
    meta = ILSVRCMeta()
    print(meta.get_per_pixel_mean())
    # Alternative checks, disabled:
    # print(meta.get_synset_words_1000())
    # ds = ILSVRC12('/home/wyx/data/imagenet', 'val')

    ds = ILSVRC12CaffeLMDB('/home/yuxinwu/', True)
    for k in ds.get_data():
        # `k` holds the current datapoint and is visible inside the shell.
        from IPython import embed
        embed()
        break
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment