Commit 52bbe706 authored by Yuxin Wu

Add TinyImageNet

parent 02e53f72
......@@ -5,6 +5,7 @@ import numpy as np
import os
import tarfile
import tqdm
from pathlib import Path
from ...utils import logger
from ...utils.fs import download, get_dataset_path, mkdir_p
......@@ -12,14 +13,14 @@ from ...utils.loadcaffe import get_caffe_pb
from ...utils.timer import timed_operation
from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files', 'TinyImageNet']
CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)
class ILSVRCMeta(object):
"""
Provide methods to access metadata for ILSVRC dataset.
Provide methods to access metadata for :class:`ILSVRC12` dataset.
"""
def __init__(self, dir=None):
......@@ -178,8 +179,11 @@ class ILSVRC12Files(RNGDataFlow):
class ILSVRC12(ILSVRC12Files):
"""
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
The label map follows the synsets.txt file in http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz.
The ILSVRC12 classification dataset, aka the commonly used 1000 classes ImageNet subset.
This dataflow produces uint8 images of shape [h, w, 3(BGR)], and a label between [0, 999].
The label map follows the synsets.txt file in
http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz,
which can also be queried using :class:`ILSVRCMeta`.
"""
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure=None):
......@@ -287,6 +291,64 @@ class ILSVRC12(ILSVRC12Files):
return ret
class TinyImageNet(RNGDataFlow):
    """
    The TinyImageNet classification dataset, with 200 classes and 500 images
    per class. See https://tiny-imagenet.herokuapp.com/.

    It produces [image, label] where image is a 64x64x3(BGR) image, label is an
    integer in [0, 200).
    """
    def __init__(self, dir, name, shuffle=None):
        """
        Args:
            dir (str): a directory. It must contain ``wnids.txt`` and a
                sub-directory called ``name``.
            name (str): one of 'train' or 'val'
            shuffle (bool): shuffle the dataset.
                Defaults to True if name=='train'.
        """
        assert name in ['train', 'val'], name
        dir = Path(os.path.expanduser(dir))
        assert os.path.isdir(dir), dir
        self.full_dir = dir / name
        if shuffle is None:
            shuffle = name == 'train'
        self.shuffle = shuffle

        # The position of a wnid in wnids.txt is its integer label.
        # Blank lines are skipped so that a trailing newline cannot introduce
        # a bogus empty class.
        with open(dir / "wnids.txt") as f:
            wnids = [line.strip() for line in f if line.strip()]
        cls_to_id = {wnid: idx for idx, wnid in enumerate(wnids)}
        assert len(cls_to_id) == 200, len(cls_to_id)

        # List of (image path, label) pairs.
        self.imglist = []
        if name == 'train':
            # Layout: <dir>/train/<wnid>/images/<image file>
            for clsid, cls in enumerate(wnids):
                cls_dir = self.full_dir / cls / "images"
                # iterdir() yields entries in arbitrary order; sort them so
                # that the order is reproducible when shuffle is False.
                for img in sorted(cls_dir.iterdir()):
                    self.imglist.append((str(img), clsid))
        else:
            # Layout: <dir>/val/images/<image file>, with labels listed in
            # <dir>/val/val_annotations.txt as "<image file> <wnid> ...".
            with open(self.full_dir / "val_annotations.txt") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # tolerate blank lines
                        continue
                    img, cls = fields[0], fields[1]
                    img = self.full_dir / "images" / img
                    clsid = cls_to_id[cls]
                    self.imglist.append((str(img), clsid))

    def __len__(self):
        """Return the number of (image, label) pairs in this split."""
        return len(self.imglist)

    def __iter__(self):
        """
        Yield ``[image, label]`` for every entry of the image list, in a
        permutation drawn from ``self.rng`` if ``self.shuffle`` is set.
        """
        idxs = np.arange(len(self.imglist))
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            fname, label = self.imglist[k]
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            # cv2.imread returns None instead of raising on failure.
            assert im is not None, fname
            yield [im, label]
try:
import cv2
except ImportError:
......@@ -297,7 +359,7 @@ if __name__ == '__main__':
meta = ILSVRCMeta()
# print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds = TinyImageNet('~/data/tiny-imagenet-200', 'val', shuffle=False)
ds.reset_state()
for _ in ds:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment