Commit 2b027291 authored by Yuxin Wu's avatar Yuxin Wu

auto detect dir_structure

parent 5e40fa3c
...@@ -202,7 +202,7 @@ def get_data(train_or_test): ...@@ -202,7 +202,7 @@ def get_data(train_or_test):
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
ds = dataset.ILSVRC12(args.data, train_or_test, ds = dataset.ILSVRC12(args.data, train_or_test,
shuffle=True if isTrain else False, dir_structure='train') shuffle=True if isTrain else False)
meta = dataset.ILSVRCMeta() meta = dataset.ILSVRCMeta()
pp_mean = meta.get_per_pixel_mean() pp_mean = meta.get_per_pixel_mean()
pp_mean_299 = cv2.resize(pp_mean, (299, 299)) pp_mean_299 = cv2.resize(pp_mean, (299, 299))
......
...@@ -72,7 +72,7 @@ def get_data(name): ...@@ -72,7 +72,7 @@ def get_data(name):
datadir = args.data datadir = args.data
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow( return get_imagenet_dataflow(
datadir, name, BATCH_SIZE, augmentors, dir_structure='original') datadir, name, BATCH_SIZE, augmentors)
def get_config(): def get_config():
......
...@@ -80,7 +80,7 @@ def get_data(name, batch): ...@@ -80,7 +80,7 @@ def get_data(name, batch):
isTrain = name == 'train' isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow( return get_imagenet_dataflow(
args.data, name, batch, augmentors, dir_structure='original') args.data, name, batch, augmentors)
def get_config(model, fake=False): def get_config(model, fake=False):
......
...@@ -82,7 +82,7 @@ def fbresnet_augmentor(isTrain): ...@@ -82,7 +82,7 @@ def fbresnet_augmentor(isTrain):
def get_imagenet_dataflow( def get_imagenet_dataflow(
datadir, name, batch_size, datadir, name, batch_size,
augmentors, dir_structure='original'): augmentors):
""" """
See explanations in the tutorial: See explanations in the tutorial:
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
...@@ -97,8 +97,7 @@ def get_imagenet_dataflow( ...@@ -97,8 +97,7 @@ def get_imagenet_dataflow(
ds = PrefetchDataZMQ(ds, cpu) ds = PrefetchDataZMQ(ds, cpu)
ds = BatchData(ds, batch_size, remainder=False) ds = BatchData(ds, batch_size, remainder=False)
else: else:
ds = dataset.ILSVRC12Files(datadir, name, ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
shuffle=False, dir_structure=dir_structure)
aug = imgaug.AugmentorList(augmentors) aug = imgaug.AugmentorList(augmentors)
def mapf(dp): def mapf(dp):
......
...@@ -203,7 +203,7 @@ if __name__ == '__main__': ...@@ -203,7 +203,7 @@ if __name__ == '__main__':
resnet_param[newname] = v resnet_param[newname] = v
if args.eval: if args.eval:
ds = ILSVRC12(args.eval, 'val', shuffle=False, dir_structure='train') ds = ILSVRC12(args.eval, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor()) ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 128, remainder=True) ds = BatchData(ds, 128, remainder=True)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
......
...@@ -78,8 +78,7 @@ def get_data(train_or_test): ...@@ -78,8 +78,7 @@ def get_data(train_or_test):
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
datadir = args.data datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test, ds = dataset.ILSVRC12(datadir, train_or_test, shuffle=isTrain)
shuffle=isTrain, dir_structure='original')
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
augmentors.append(imgaug.ToUint8()) augmentors.append(imgaug.ToUint8())
......
...@@ -105,6 +105,20 @@ class ILSVRCMeta(object): ...@@ -105,6 +105,20 @@ class ILSVRCMeta(object):
return arr return arr
def _guess_dir_structure(dir):
subdir = os.listdir(dir)[0]
# find a subdir starting with 'n'
if subdir.startswith('n') and \
os.path.isdir(os.path.join(dir, subdir)):
dir_structure = 'train'
else:
dir_structure = 'original'
logger.info(
"Assuming directory {} has {} structure.".format(
dir, dir_structure))
return dir_structure
class ILSVRC12Files(RNGDataFlow): class ILSVRC12Files(RNGDataFlow):
""" """
Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays. Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays.
...@@ -112,7 +126,7 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -112,7 +126,7 @@ class ILSVRC12Files(RNGDataFlow):
decode it in smarter ways (e.g. in parallel). decode it in smarter ways (e.g. in parallel).
""" """
def __init__(self, dir, name, meta_dir=None, def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'): shuffle=None, dir_structure=None):
""" """
Same as in :class:`ILSVRC12`. Same as in :class:`ILSVRC12`.
""" """
...@@ -124,6 +138,12 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -124,6 +138,12 @@ class ILSVRC12Files(RNGDataFlow):
if shuffle is None: if shuffle is None:
shuffle = name == 'train' shuffle = name == 'train'
self.shuffle = shuffle self.shuffle = shuffle
if name == 'train':
dir_structure = 'train'
if dir_structure is None:
dir_structure = _guess_dir_structure(self.full_dir)
meta = ILSVRCMeta(meta_dir) meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure) self.imglist = meta.get_image_list(name, dir_structure)
...@@ -149,7 +169,7 @@ class ILSVRC12(ILSVRC12Files): ...@@ -149,7 +169,7 @@ class ILSVRC12(ILSVRC12Files):
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999]. Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
""" """
def __init__(self, dir, name, meta_dir=None, def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'): shuffle=None, dir_structure=None):
""" """
Args: Args:
dir (str): A directory containing a subdir named ``name``, where the dir (str): A directory containing a subdir named ``name``, where the
...@@ -162,6 +182,7 @@ class ILSVRC12(ILSVRC12Files): ...@@ -162,6 +182,7 @@ class ILSVRC12(ILSVRC12Files):
directory, which only has list of image files (as below). directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level If set to 'train', it expects the same two-level
directory structure simlar to 'train/'. directory structure simlar to 'train/'.
By default, it tries to automatically detect the structure.
Examples: Examples:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment