Commit 70e14a6b authored by Yuxin Wu's avatar Yuxin Wu

dump and load dataflow with lmdb

parent 4fcb47f3
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imgclassify.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.dataflow import *
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='model')
parser.add_argument(dest='images', nargs='+')
parser.add_argument('--output_type', default='label',
choices=['label', 'label-prob', 'raw'])
parser.add_argument('--top', default=1, type=int)
args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
# TODO not sure if it this script is still working
with tf.Graph().as_default() as G:
train_config = get_config_func()
M = train_config.model
config = PredictConfig(
input_var_names=[M.get_input_vars_desc()[0].name], # assume first component is image
model=M,
session_init=sessinit.SaverRestore(args.model),
output_var_names=['output:0']
)
ds = ImageFromFile(args.images, 3, resize=(227, 227))
ds = BatchData(ds, 128, remainder=True)
predictor = SimpleDatasetPredictor(config, ds)
res = predictor.get_all_result()
if args.output_type == 'label':
for r in res:
print r[0].argsort(axis=1)[:,-args.top:][:,::-1]
elif args.output_type == 'label_prob':
raise NotImplementedError
elif args.output_type == 'raw':
print res
......@@ -96,7 +96,13 @@ class InferenceRunner(Callback):
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.get_input_vars()
self.input_tensors = [x.name for x in input_vars]
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def get_name(x):
if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0]
return x.name
self.input_tensors = [get_name(x) for x in input_vars]
def _find_output_tensors(self):
dispatcer = OutputTensorDispatcer()
......
......@@ -5,11 +5,21 @@
import sys, os
import cv2
import multiprocessing as mp
import six
from six.moves import range, map
from ..utils import get_tqdm, logger
from ..utils.concurrency import DIE
from ..utils.serialize import dumps
from ..utils.fs import mkdir_p
__all__ = ['dump_dataset_images', 'dataflow_to_process_queue']
try:
import lmdb
except ImportError:
logger.warn_dependency("dump_dataflow_to_lmdb", 'lmdb')
else:
__all__.extend(['dump_dataflow_to_lmdb'])
# TODO pass a name_func to write label as filename?
def dump_dataset_images(ds, dirname, max_count=None, index=0):
......@@ -32,6 +42,28 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
img = dp[index]
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path):
isdir = os.path.isdir(lmdb_path)
if isdir:
assert not os.path.isfile(os.path.join(lmdb_path, 'data.mdb')), "LMDB file exists!"
else:
assert not os.path.isfile(lmdb_path), "LMDB file exists!"
ds.reset_state()
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True) # need sync() at the end
with get_tqdm(total=ds.size()) as pbar:
with db.begin(write=True) as txn:
for idx, dp in enumerate(ds.get_data()):
txn.put(six.binary_type(idx), dumps(dp))
pbar.update()
keys = list(map(six.binary_type, range(idx + 1)))
txn.put('__keys__', dumps(keys))
logger.info("Flushing database ...")
db.sync()
db.close()
def dataflow_to_process_queue(ds, size, nr_consumer):
"""
Convert a `DataFlow` to a multiprocessing.Queue.
......
......@@ -4,10 +4,12 @@
import numpy as np
from six.moves import range
import os
from ..utils import logger, get_rng, get_tqdm
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads
from .base import RNGDataFlow
try:
......@@ -23,7 +25,7 @@ try:
except ImportError:
logger.warn_dependency("LMDBData", 'lmdb')
else:
__all__.extend(['LMDBData', 'CaffeLMDB', 'LMDBDataDecoder'])
__all__.extend(['LMDBData', 'CaffeLMDB', 'LMDBDataDecoder', 'LMDBDataPoint'])
try:
import sklearn.datasets
......@@ -69,15 +71,16 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_dir, shuffle=True):
self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
def __init__(self, lmdb_path, shuffle=True):
self._lmdb = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
readonly=True, lock=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._shuffle = shuffle
self._size = self._txn.stat()['entries']
if shuffle:
# get the list of keys either from __keys__ or by iterating
self.keys = self._txn.get('__keys__')
self.keys = loads(self._txn.get('__keys__'))
if not self.keys:
self.keys = []
with timed_operation("Loading LMDB keys ...", log_start=True), \
......@@ -109,12 +112,12 @@ class LMDBData(RNGDataFlow):
yield [k, v]
class LMDBDataDecoder(LMDBData):
def __init__(self, lmdb_dir, decoder, shuffle=True):
def __init__(self, lmdb_path, decoder, shuffle=True):
"""
:param decoder: a function taking k, v and return a data point,
or return None to skip
"""
super(LMDBDataDecoder, self).__init__(lmdb_dir, shuffle)
super(LMDBDataDecoder, self).__init__(lmdb_path, shuffle)
self.decoder = decoder
def get_data(self):
......@@ -122,9 +125,15 @@ class LMDBDataDecoder(LMDBData):
v = self.decoder(dp[0], dp[1])
if v: yield v
class LMDBDataPoint(LMDBDataDecoder):
""" Read a LMDB file where each value is a serialized datapoint"""
def __init__(self, lmdb_path, shuffle=True):
super(SimpleLMDBLoader, self).__init__(
lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
class CaffeLMDB(LMDBDataDecoder):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
def __init__(self, lmdb_path, shuffle=True):
cpb = get_caffe_pb()
def decoder(k, v):
try:
......@@ -138,7 +147,7 @@ class CaffeLMDB(LMDBDataDecoder):
return [img.transpose(1, 2, 0), datum.label]
super(CaffeLMDB, self).__init__(
lmdb_dir, decoder=decoder, shuffle=shuffle)
lmdb_path, decoder=decoder, shuffle=shuffle)
class SVMLightData(RNGDataFlow):
""" Read X,y from a svmlight file """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment