Commit 42322257 authored by Yuxin Wu's avatar Yuxin Wu

add an ILSVRC12Files dataset (#139)

parent 43b04fc8
...@@ -4,10 +4,8 @@ ...@@ -4,10 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
import tarfile import tarfile
import six
import numpy as np import numpy as np
import tqdm import tqdm
import xml.etree.ElementTree as ET
from ...utils import logger from ...utils import logger
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
...@@ -15,7 +13,7 @@ from ...utils.fs import mkdir_p, download, get_dataset_path ...@@ -15,7 +13,7 @@ from ...utils.fs import mkdir_p, download, get_dataset_path
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12'] __all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
...@@ -107,13 +105,48 @@ class ILSVRCMeta(object): ...@@ -107,13 +105,48 @@ class ILSVRCMeta(object):
return arr return arr
class ILSVRC12(RNGDataFlow): class ILSVRC12Files(RNGDataFlow):
"""
Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays.
This could be useful when ``cv2.imread`` is a bottleneck and you want to
decode it in smarter ways (e.g. in parallel).
"""
def __init__(self, dir, name, meta_dir=None,
             shuffle=None, dir_structure='original'):
    """
    Same as in :class:`ILSVRC12`.
    """
    # Validate the requested split and the directory layout up front.
    assert name in ('train', 'test', 'val'), name
    assert os.path.isdir(dir), dir
    self.full_dir = os.path.join(dir, name)
    self.name = name
    assert os.path.isdir(self.full_dir), self.full_dir
    # Default behaviour: shuffle only the training split.
    self.shuffle = (name == 'train') if shuffle is None else shuffle
    # Load (filename, label) pairs from the meta files.
    self.imglist = ILSVRCMeta(meta_dir).get_image_list(name, dir_structure)
def size(self):
    """Return the total number of datapoints (one per listed image)."""
    total = len(self.imglist)
    return total
def get_data(self):
    """Yield ``[filename, label]`` datapoints, optionally in shuffled order."""
    order = np.arange(len(self.imglist))
    if self.shuffle:
        # self.rng is set up by RNGDataFlow.reset_state().
        self.rng.shuffle(order)
    base = self.full_dir
    for idx in order:
        name, label = self.imglist[idx]
        yield [os.path.join(base, name), label]
class ILSVRC12(ILSVRC12Files):
""" """
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999], Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax]. and optionally a bounding box of [xmin, ymin, xmax, ymax].
""" """
def __init__(self, dir, name, meta_dir=None, shuffle=None, def __init__(self, dir, name, meta_dir=None,
dir_structure='original', include_bb=False): shuffle=None, dir_structure='original'):
""" """
Args: Args:
dir (str): A directory containing a subdir named ``name``, where the dir (str): A directory containing a subdir named ``name``, where the
...@@ -126,7 +159,6 @@ class ILSVRC12(RNGDataFlow): ...@@ -126,7 +159,6 @@ class ILSVRC12(RNGDataFlow):
directory, which only has list of image files (as below). directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level If set to 'train', it expects the same two-level
directory structure similar to 'train/'. directory structure similar to 'train/'.
include_bb (bool): Include the bounding box. Maybe useful in training.
Examples: Examples:
...@@ -157,50 +189,18 @@ class ILSVRC12(RNGDataFlow): ...@@ -157,50 +189,18 @@ class ILSVRC12(RNGDataFlow):
mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
""" """
assert name in ['train', 'test', 'val'], name super(ILSVRC12, self).__init__(
assert os.path.isdir(dir), dir dir, name, meta_dir, shuffle, dir_structure)
self.full_dir = os.path.join(dir, name)
self.name = name
assert os.path.isdir(self.full_dir), self.full_dir
if shuffle is None:
shuffle = name == 'train'
self.shuffle = shuffle
meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure)
self.synset = meta.get_synset_1000()
if include_bb:
bbdir = os.path.join(dir, 'bbox') if not \
isinstance(include_bb, six.string_types) else include_bb
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
self.include_bb = include_bb
def size(self):
return len(self.imglist)
def get_data(self): def get_data(self):
idxs = np.arange(len(self.imglist)) for fname, label in super(ILSVRC12, self).get_data():
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
fname, label = self.imglist[k]
fname = os.path.join(self.full_dir, fname)
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: yield [im, label]
im = np.expand_dims(im, 2).repeat(3, 2)
if self.include_bb:
bb = self.bblist[k]
if bb is None:
bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
yield [im, label, bb]
else:
yield [im, label]
@staticmethod @staticmethod
def get_training_bbox(bbox_dir, imglist): def get_training_bbox(bbox_dir, imglist):
import xml.etree.ElementTree as ET
ret = [] ret = []
def parse_bbox(fname): def parse_bbox(fname):
...@@ -239,8 +239,7 @@ if __name__ == '__main__': ...@@ -239,8 +239,7 @@ if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
# print(meta.get_synset_words_1000()) # print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True, ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
shuffle=False)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds.get_data():
......
...@@ -158,8 +158,9 @@ class LMDBDataDecoder(MapData): ...@@ -158,8 +158,9 @@ class LMDBDataDecoder(MapData):
class LMDBDataPoint(MapData): class LMDBDataPoint(MapData):
""" """
Read a LMDB file and produce deserialized datapoints. Read a LMDB file and produce deserialized datapoints.
It reads the database produced by It only accepts the database produced by
:func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`,
which uses :func:`tensorpack.utils.serialize.dumps` for serialization.
Example: Example:
.. code-block:: python .. code-block:: python
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment