Commit 42322257 authored by Yuxin Wu's avatar Yuxin Wu

add an ILSVRC12Files dataset (#139)

parent 43b04fc8
......@@ -4,10 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tarfile
import six
import numpy as np
import tqdm
import xml.etree.ElementTree as ET
from ...utils import logger
from ...utils.loadcaffe import get_caffe_pb
......@@ -15,7 +13,7 @@ from ...utils.fs import mkdir_p, download, get_dataset_path
from ...utils.timer import timed_operation
from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12']
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
......@@ -107,13 +105,48 @@ class ILSVRCMeta(object):
return arr
class ILSVRC12(RNGDataFlow):
class ILSVRC12Files(RNGDataFlow):
    """
    Filename-producing counterpart of :class:`ILSVRC12`: each datapoint is
    ``[filename, label]`` rather than a decoded image array.
    Handy when ``cv2.imread`` is the bottleneck and the decoding should be
    done elsewhere (e.g. in parallel workers).
    """

    def __init__(self, dir, name, meta_dir=None,
                 shuffle=None, dir_structure='original'):
        """
        Takes the same arguments as :class:`ILSVRC12`.

        When ``shuffle`` is None, shuffling is turned on only for the
        'train' split.
        """
        assert name in ('train', 'test', 'val'), name
        assert os.path.isdir(dir), dir
        # Images of a split live in the sub-directory named after the split.
        self.full_dir = os.path.join(dir, name)
        self.name = name
        assert os.path.isdir(self.full_dir), self.full_dir
        self.shuffle = (name == 'train') if shuffle is None else shuffle
        self.imglist = ILSVRCMeta(meta_dir).get_image_list(name, dir_structure)

    def size(self):
        """Return the number of entries in the image list."""
        return len(self.imglist)

    def get_data(self):
        """
        Yield ``[filename, label]`` for every entry of the image list, with
        the filename joined onto the split directory. The visiting order is
        shuffled with ``self.rng`` when ``self.shuffle`` is true.
        """
        # NOTE(review): self.rng is presumably provided by RNGDataFlow after
        # reset_state() -- confirm against the base class.
        order = np.arange(len(self.imglist))
        if self.shuffle:
            self.rng.shuffle(order)
        for idx in order:
            relpath, label = self.imglist[idx]
            yield [os.path.join(self.full_dir, relpath), label]
class ILSVRC12(ILSVRC12Files):
"""
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax].
"""
def __init__(self, dir, name, meta_dir=None, shuffle=None,
dir_structure='original', include_bb=False):
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'):
"""
Args:
dir (str): A directory containing a subdir named ``name``, where the
......@@ -126,7 +159,6 @@ class ILSVRC12(RNGDataFlow):
directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level
directory structure similar to 'train/'.
include_bb (bool): Include the bounding box. Maybe useful in training.
Examples:
......@@ -157,50 +189,18 @@ class ILSVRC12(RNGDataFlow):
mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
"""
assert name in ['train', 'test', 'val'], name
assert os.path.isdir(dir), dir
self.full_dir = os.path.join(dir, name)
self.name = name
assert os.path.isdir(self.full_dir), self.full_dir
if shuffle is None:
shuffle = name == 'train'
self.shuffle = shuffle
meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure)
self.synset = meta.get_synset_1000()
if include_bb:
bbdir = os.path.join(dir, 'bbox') if not \
isinstance(include_bb, six.string_types) else include_bb
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
self.include_bb = include_bb
def size(self):
return len(self.imglist)
super(ILSVRC12, self).__init__(
dir, name, meta_dir, shuffle, dir_structure)
def get_data(self):
idxs = np.arange(len(self.imglist))
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
fname, label = self.imglist[k]
fname = os.path.join(self.full_dir, fname)
for fname, label in super(ILSVRC12, self).get_data():
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3, 2)
if self.include_bb:
bb = self.bblist[k]
if bb is None:
bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
yield [im, label, bb]
else:
yield [im, label]
@staticmethod
def get_training_bbox(bbox_dir, imglist):
import xml.etree.ElementTree as ET
ret = []
def parse_bbox(fname):
......@@ -239,8 +239,7 @@ if __name__ == '__main__':
meta = ILSVRCMeta()
# print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
shuffle=False)
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds.reset_state()
for k in ds.get_data():
......
......@@ -158,8 +158,9 @@ class LMDBDataDecoder(MapData):
class LMDBDataPoint(MapData):
"""
Read a LMDB file and produce deserialized datapoints.
It reads the database produced by
:func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`.
It only accepts the database produced by
:func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`,
which uses :func:`tensorpack.utils.serialize.dumps` for serialization.
Example:
.. code-block:: python
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment