Commit 67dd4887 authored by Yuxin Wu's avatar Yuxin Wu

dir structure for ilsvrc

parent b2ac1dae
...@@ -44,6 +44,15 @@ class ILSVRCMeta(object): ...@@ -44,6 +44,15 @@ class ILSVRCMeta(object):
lines = [x.strip() for x in open(fname).readlines()] lines = [x.strip() for x in open(fname).readlines()]
return dict(enumerate(lines)) return dict(enumerate(lines))
def get_synset_1000(self):
    """
    Read `synsets.txt` from the meta directory (`self.dir`).

    Each line of the file is stripped of surrounding whitespace and
    keyed by its 0-based line index.

    :returns a dict of {cls_number: synset_id}
    :raises AssertionError: if `synsets.txt` does not exist under `self.dir`
    """
    fname = os.path.join(self.dir, 'synsets.txt')
    # Include the path in the failure so a missing meta file is easy to spot.
    assert os.path.isfile(fname), fname
    # Read inside a context manager so the file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(fname) as f:
        lines = [x.strip() for x in f]
    return dict(enumerate(lines))
def _download_caffe_meta(self): def _download_caffe_meta(self):
fpath = download(CAFFE_ILSVRC12_URL, self.dir) fpath = download(CAFFE_ILSVRC12_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
...@@ -80,11 +89,16 @@ class ILSVRCMeta(object): ...@@ -80,11 +89,16 @@ class ILSVRCMeta(object):
return arr return arr
class ILSVRC12(RNGDataFlow): class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True): def __init__(self, dir, name, meta_dir=None, shuffle=True,
dir_structure='original'):
""" """
:param dir: A directory containing a subdir named `name`, where the :param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed. original ILSVRC12_`name`.tar gets decompressed.
:param name: 'train' or 'val' or 'test' :param name: 'train' or 'val' or 'test'
:param dir_structure: the dir structure of 'val' or 'test'.
If 'original', keep the original decompressed dir with its list
of image files. If 'train', use the `train/` dir
structure, with class names as subdirectories.
Dir should have the following structure: Dir should have the following structure:
...@@ -114,10 +128,13 @@ class ILSVRC12(RNGDataFlow): ...@@ -114,10 +128,13 @@ class ILSVRC12(RNGDataFlow):
""" """
assert name in ['train', 'test', 'val'] assert name in ['train', 'test', 'val']
self.full_dir = os.path.join(dir, name) self.full_dir = os.path.join(dir, name)
self.name = name
assert os.path.isdir(self.full_dir), self.full_dir assert os.path.isdir(self.full_dir), self.full_dir
self.shuffle = shuffle self.shuffle = shuffle
self.meta = ILSVRCMeta(meta_dir) meta = ILSVRCMeta(meta_dir)
self.imglist = self.meta.get_image_list(name) self.imglist = meta.get_image_list(name)
self.dir_structure = dir_structure
self.synset = meta.get_synset_1000()
def size(self): def size(self):
return len(self.imglist) return len(self.imglist)
...@@ -127,16 +144,20 @@ class ILSVRC12(RNGDataFlow): ...@@ -127,16 +144,20 @@ class ILSVRC12(RNGDataFlow):
Produce original images or shape [h, w, 3], and label Produce original images or shape [h, w, 3], and label
""" """
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
isTrain = self.name == 'train'
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
tp = self.imglist[k] fname, label = self.imglist[k]
fname = os.path.join(self.full_dir, tp[0]).strip() if not isTrain and self.dir_structure != 'original':
im = cv2.imread(fname, cv2.IMREAD_COLOR) fname = os.path.join(self.full_dir, self.synset[label], fname)
else:
fname = os.path.join(self.full_dir, fname)
im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2) im = np.expand_dims(im, 2).repeat(3,2)
yield [im, tp[1]] yield [im, label]
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment