Commit 33efcb82 authored by meak's avatar meak Committed by Yuxin Wu

feat: add Caltech101 Silouettes dataset (#1158)

* feat: add Caltech101 dataset

* feat: rename as Caltech101Silhouettes and split in train, val and test
parent 9f1a4110
......@@ -10,6 +10,7 @@ if STATICA_HACK:
from .ilsvrc import *
from .mnist import *
from .svhn import *
from .caltech101 import *
from pkgutil import iter_modules
import os
......
# -*- coding: utf-8 -*-
# File: caltech101.py
import os
from six.moves import range
from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow
__all__ = ["Caltech101Silhouettes"]
def maybe_download(url, work_directory):
"""Download the data from Marlin's website, unless it's already here."""
filename = url.split("/")[-1]
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
logger.info("Downloading to {}...".format(filepath))
download(url, work_directory)
return filepath
class Caltech101Silhouettes(RNGDataFlow):
"""
Produces [image, label] in Caltech101 Silhouettes dataset,
image is 28x28 in the range [0,1], label is an int in the range [0,100].
"""
_DIR_NAME = "caltech101_data"
_SOURCE_URL = "https://people.cs.umass.edu/~marlin/data/"
def __init__(self, name, shuffle=True, dir=None):
"""
Args:
name (str): 'train', 'test', 'val'
shuffle (bool): shuffle the dataset
"""
if dir is None:
dir = get_dataset_path(self._DIR_NAME)
assert name in ['train', 'test', 'val']
self.name = name
self.shuffle = shuffle
def get_images_and_labels(data_file):
f = maybe_download(self._SOURCE_URL + data_file, dir)
data = scipy.io.loadmat(f)
return data
self.data = get_images_and_labels("caltech101_silhouettes_28_split1.mat")
if self.train_or_test == "train":
self.images = self.data["train_data"].reshape((4100, 28, 28))
self.labels = self.data["train_labels"].ravel() - 1
elif self.name == "test":
self.images = self.data["test_data"].reshape((2307, 28, 28))
self.labels = self.data["test_labels"].ravel() - 1
else:
self.images = self.data["val_data"].reshape((2264, 28, 28))
self.labels = self.data["val_labels"].ravel() - 1
def __len__(self):
return self.images.shape[0]
def __iter__(self):
idxs = list(range(self.__len__()))
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
img = self.images[k]
label = self.labels[k]
yield [img, label]
try:
import scipy.io
except ImportError:
from ...utils.develop import create_dummy_class
Caltech101Silhouettes = create_dummy_class('Caltech101Silhouettes', 'scipy.io') # noqa
if __name__ == "__main__":
ds = Caltech101Silhouettes("train")
ds.reset_state()
for (img, label) in ds:
from IPython import embed
embed()
break
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment