Commit 52bbe706 authored by Yuxin Wu

Add TinyImageNet

parent 02e53f72
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import os import os
import tarfile import tarfile
import tqdm import tqdm
from pathlib import Path
from ...utils import logger from ...utils import logger
from ...utils.fs import download, get_dataset_path, mkdir_p from ...utils.fs import download, get_dataset_path, mkdir_p
...@@ -12,14 +13,14 @@ from ...utils.loadcaffe import get_caffe_pb ...@@ -12,14 +13,14 @@ from ...utils.loadcaffe import get_caffe_pb
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] __all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files', 'TinyImageNet']
CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008) CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)
class ILSVRCMeta(object): class ILSVRCMeta(object):
""" """
Provide methods to access metadata for ILSVRC dataset. Provide methods to access metadata for :class:`ILSVRC12` dataset.
""" """
def __init__(self, dir=None): def __init__(self, dir=None):
...@@ -178,8 +179,11 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -178,8 +179,11 @@ class ILSVRC12Files(RNGDataFlow):
class ILSVRC12(ILSVRC12Files): class ILSVRC12(ILSVRC12Files):
""" """
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999]. The ILSVRC12 classification dataset, aka the commonly used 1000 classes ImageNet subset.
The label map follows the synsets.txt file in http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz. This dataflow produces uint8 images of shape [h, w, 3(BGR)], and a label between [0, 999].
The label map follows the synsets.txt file in
http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz,
which can also be queried using :class:`ILSVRCMeta`.
""" """
def __init__(self, dir, name, meta_dir=None, def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure=None): shuffle=None, dir_structure=None):
...@@ -287,6 +291,64 @@ class ILSVRC12(ILSVRC12Files): ...@@ -287,6 +291,64 @@ class ILSVRC12(ILSVRC12Files):
return ret return ret
class TinyImageNet(RNGDataFlow):
    """
    The TinyImageNet classification dataset, with 200 classes and 500 images
    per class. See https://tiny-imagenet.herokuapp.com/.

    It produces [image, label] where image is a 64x64x3(BGR) image, label is an
    integer in [0, 200).
    """
    def __init__(self, dir, name, shuffle=None):
        """
        Args:
            dir (str): the dataset root directory. It must contain ``wnids.txt``
                and a sub-directory named after ``name``. A leading ``~`` is expanded.
            name (str): one of 'train' or 'val'
            shuffle (bool): shuffle the dataset.
                Defaults to True if name=='train'.
        """
        assert name in ['train', 'val'], name
        dir = Path(os.path.expanduser(dir))
        assert dir.is_dir(), dir
        self.full_dir = dir / name
        if shuffle is None:
            shuffle = name == 'train'
        self.shuffle = shuffle

        # wnids.txt lists one class name per line; the line index is the label.
        # Blank lines are ignored so that they never become a (bogus) class.
        with open(dir / "wnids.txt") as f:
            wnids = [line.strip() for line in f if line.strip()]
        cls_to_id = {wnid: idx for idx, wnid in enumerate(wnids)}
        assert len(cls_to_id) == 200, len(cls_to_id)

        # List of (image path, integer label) pairs.
        self.imglist = []
        if name == 'train':
            # Training layout: <dir>/train/<wnid>/images/<image files>
            for clsid, cls in enumerate(wnids):
                cls_dir = self.full_dir / cls / "images"
                # iterdir() yields entries in arbitrary order; sort them so that
                # shuffle=False gives the same order on every run and machine.
                for img in sorted(cls_dir.iterdir()):
                    self.imglist.append((str(img), clsid))
        else:
            # Validation layout: <dir>/val/images/<image files>, with labels given
            # by val_annotations.txt (first field: file name, second field: wnid).
            with open(self.full_dir / "val_annotations.txt") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # tolerate blank lines (e.g. a trailing newline at EOF)
                        continue
                    img, cls = fields[0], fields[1]
                    img = self.full_dir / "images" / img
                    clsid = cls_to_id[cls]
                    self.imglist.append((str(img), clsid))

    def __len__(self):
        """Return the number of (image, label) pairs in the selected split."""
        return len(self.imglist)

    def __iter__(self):
        """
        Yield ``[image, label]`` for every entry of the split, where image is
        the array returned by ``cv2.imread`` and label is the integer class id.
        """
        idxs = np.arange(len(self.imglist))
        if self.shuffle:
            # NOTE(review): self.rng is not set in this class; presumably it is
            # provided by RNGDataFlow once reset_state() has been called.
            self.rng.shuffle(idxs)
        for k in idxs:
            fname, label = self.imglist[k]
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            # cv2.imread returns None instead of raising on an unreadable file
            assert im is not None, fname
            yield [im, label]
try: try:
import cv2 import cv2
except ImportError: except ImportError:
...@@ -297,7 +359,7 @@ if __name__ == '__main__': ...@@ -297,7 +359,7 @@ if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
# print(meta.get_synset_words_1000()) # print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) ds = TinyImageNet('~/data/tiny-imagenet-200', 'val', shuffle=False)
ds.reset_state() ds.reset_state()
for _ in ds: for _ in ds:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment