Commit 36bdc187 authored by Yuxin Wu

add places365 dataset

parent ecf525d6
...@@ -192,8 +192,7 @@ class ILSVRC12(ILSVRC12Files): ...@@ -192,8 +192,7 @@ class ILSVRC12(ILSVRC12Files):
dir (str): A directory containing a subdir named ``name``, dir (str): A directory containing a subdir named ``name``,
containing the images in a structure described below. containing the images in a structure described below.
name (str): One of 'train' or 'val' or 'test'. name (str): One of 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset. shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
Defaults to True if name=='train'.
dir_structure (str): One of 'original' or 'train'. dir_structure (str): One of 'original' or 'train'.
The directory structure for the 'val' directory. The directory structure for the 'val' directory.
'original' means the original decompressed directory, which only has list of image files (as below). 'original' means the original decompressed directory, which only has list of image files (as below).
...@@ -354,6 +353,7 @@ try: ...@@ -354,6 +353,7 @@ try:
except ImportError: except ImportError:
from ...utils.develop import create_dummy_class from ...utils.develop import create_dummy_class
ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa
TinyImageNet = create_dummy_class('TinyImageNet', 'cv2') # noqa
if __name__ == '__main__': if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
......
#-*- coding: utf-8 -*-
import os
import numpy as np
from ...utils import logger
from ..base import RNGDataFlow
class Places365Standard(RNGDataFlow):
    """
    The Places365-Standard Dataset, in low resolution format only.
    Produces BGR images of shape (256, 256, 3) in range [0, 255].

    Each datapoint is ``[image, label]``, where ``label`` is the integer index
    of the image's class in the list returned by :meth:`get_label_names`.
    """

    def __init__(self, dir, name, shuffle=None):
        """
        Args:
            dir: path to the Places365-Standard dataset in its "easy directory
                structure". See http://places2.csail.mit.edu/download.html
            name: one of "train" or "val"
            shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
        """
        assert name in ['train', 'val'], name
        dir = os.path.expanduser(dir)
        assert os.path.isdir(dir), dir
        self.name = name
        if shuffle is None:
            shuffle = name == 'train'
        self.shuffle = shuffle

        # The image list is read from "<dir>/<name>.txt", one relative path per line.
        label_file = os.path.join(dir, name + ".txt")
        all_files = []
        labels = set()
        with open(label_file) as f:
            for line in f:
                relpath = line.strip()
                if not relpath:
                    # Skip blank lines instead of failing with an IndexError
                    # when looking up the class component below.
                    continue
                # The second path component of each entry is used as the class name.
                label = relpath.split("/")[1]
                all_files.append((os.path.join(dir, relpath), label))
                labels.add(label)

        self._labels = sorted(labels)
        # class ids are sorted alphabetically:
        # https://github.com/CSAILVision/places365/blob/master/categories_places365.txt
        labelmap = {label: idx for idx, label in enumerate(self._labels)}
        # Store (absolute path, integer class id) pairs.
        self._files = [(path, labelmap[x]) for path, x in all_files]
        logger.info("Found {} images in {}.".format(len(self._files), label_file))

    def get_label_names(self):
        """
        Returns:
            [str]: name of each class.
        """
        return self._labels

    def __len__(self):
        return len(self._files)

    def __iter__(self):
        """
        Yield ``[image, label]`` for every file in the list, reading each image
        from disk with OpenCV. The order is shuffled when ``self.shuffle`` is set.
        """
        idxs = np.arange(len(self._files))
        if self.shuffle:
            # NOTE(review): self.rng is not set in this class; presumably it is
            # provided by RNGDataFlow after reset_state() -- confirm.
            self.rng.shuffle(idxs)
        for k in idxs:
            fname, label = self._files[k]
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            # cv2.imread signals an unreadable file by returning None.
            assert im is not None, fname
            yield [im, label]
# cv2 is imported at the bottom of the module so that, when it is missing,
# the class defined above can be replaced by a placeholder.
try:
    import cv2
except ImportError:
    from ...utils.develop import create_dummy_class
    # Rebind the public name to a dummy class created for the missing 'cv2'
    # dependency. NOTE(review): presumably the dummy raises a helpful error
    # on use -- confirm against utils.develop.create_dummy_class.
    Places365Standard = create_dummy_class('Places365Standard', 'cv2') # noqa
if __name__ == '__main__':
    # Manual smoke test: wrap the training split in PrintData (num=100) and
    # run through the whole dataflow once, discarding every datapoint.
    from tensorpack.dataflow import PrintData

    ds = PrintData(Places365Standard("~/data/places365_standard/", 'train'), num=100)
    ds.reset_state()
    for _ in ds:
        pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment