Commit a1477a75 authored by Yuxin Wu's avatar Yuxin Wu

Add "dir_structure" option to ILSVRCMeta

parent 8901482c
...@@ -52,7 +52,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn): ...@@ -52,7 +52,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
while not self.stopped(): while not self.stopped():
try: try:
score = play_one_episode(player, self.func) score = play_one_episode(player, self.func)
# print "Score, ", score # print("Score, ", score)
except RuntimeError: except RuntimeError:
return return
self.queue_put_stoppable(self.q, score) self.queue_put_stoppable(self.q, score)
......
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
import tarfile import tarfile
import six import six
import numpy as np import numpy as np
import tqdm
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from ...utils import logger from ...utils import logger
...@@ -58,22 +59,33 @@ class ILSVRCMeta(object): ...@@ -58,22 +59,33 @@ class ILSVRCMeta(object):
fpath = download(CAFFE_ILSVRC12_URL, self.dir) fpath = download(CAFFE_ILSVRC12_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
def get_image_list(self, name): def get_image_list(self, name, dir_structure='original'):
""" """
Args: Args:
name (str): 'train' or 'val' or 'test' name (str): 'train' or 'val' or 'test'
dir_structure (str): same as in :meth:`ILSVRC12.__init__()`.
Returns: Returns:
list: list of (image filename, label) list: list of (image filename, label)
""" """
assert name in ['train', 'val', 'test'] assert name in ['train', 'val', 'test']
assert dir_structure in ['original', 'train']
add_label_to_fname = (name != 'train' and dir_structure != 'original')
if add_label_to_fname:
synset = self.get_synset_1000()
fname = os.path.join(self.dir, name + '.txt') fname = os.path.join(self.dir, name + '.txt')
assert os.path.isfile(fname) assert os.path.isfile(fname), fname
with open(fname) as f: with open(fname) as f:
ret = [] ret = []
for line in f.readlines(): for line in f.readlines():
name, cls = line.strip().split() name, cls = line.strip().split()
ret.append((name, int(cls))) cls = int(cls)
assert len(ret)
if add_label_to_fname:
name = os.path.join(synset[cls], name)
ret.append((name.strip(), cls))
assert len(ret), fname
return ret return ret
def get_per_pixel_mean(self, size=None): def get_per_pixel_mean(self, size=None):
...@@ -109,8 +121,8 @@ class ILSVRC12(RNGDataFlow): ...@@ -109,8 +121,8 @@ class ILSVRC12(RNGDataFlow):
name (str): 'train' or 'val' or 'test'. name (str): 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset. shuffle (bool): shuffle the dataset.
Defaults to True if name=='train'. Defaults to True if name=='train'.
dir_structure (str): The dir structure of 'val' and 'test' directory. dir_structure (str): The directory structure of 'val' and 'test' directory.
If is 'original', it expects the original decompressed 'original' means the original decompressed
directory, which only has list of image files (as below). directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level If set to 'train', it expects the same two-level
directory structure simlar to 'train/'. directory structure simlar to 'train/'.
...@@ -135,7 +147,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -135,7 +147,7 @@ class ILSVRC12(RNGDataFlow):
ILSVRC2012_test_00000001.JPEG ILSVRC2012_test_00000001.JPEG
... ...
With ILSVRC12_img_*.tar, you can use the following With the downloaded ILSVRC12_img_*.tar, you can use the following
command to build the above structure: command to build the above structure:
.. code-block:: none .. code-block:: none
...@@ -154,8 +166,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -154,8 +166,7 @@ class ILSVRC12(RNGDataFlow):
shuffle = name == 'train' shuffle = name == 'train'
self.shuffle = shuffle self.shuffle = shuffle
meta = ILSVRCMeta(meta_dir) meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name) self.imglist = meta.get_image_list(name, dir_structure)
self.dir_structure = dir_structure
self.synset = meta.get_synset_1000() self.synset = meta.get_synset_1000()
if include_bb: if include_bb:
...@@ -170,16 +181,13 @@ class ILSVRC12(RNGDataFlow): ...@@ -170,16 +181,13 @@ class ILSVRC12(RNGDataFlow):
def get_data(self): def get_data(self):
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
fname, label = self.imglist[k] fname, label = self.imglist[k]
if add_label_to_fname:
fname = os.path.join(self.full_dir, self.synset[label], fname)
else:
fname = os.path.join(self.full_dir, fname) fname = os.path.join(self.full_dir, fname)
im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3, 2) im = np.expand_dims(im, 2).repeat(3, 2)
...@@ -210,7 +218,6 @@ class ILSVRC12(RNGDataFlow): ...@@ -210,7 +218,6 @@ class ILSVRC12(RNGDataFlow):
with timed_operation('Loading Bounding Boxes ...'): with timed_operation('Loading Bounding Boxes ...'):
cnt = 0 cnt = 0
import tqdm
for k in tqdm.trange(len(imglist)): for k in tqdm.trange(len(imglist)):
fname = imglist[k][0] fname = imglist[k][0]
fname = fname[:-4] + 'xml' fname = fname[:-4] + 'xml'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment