Commit 2b027291 authored by Yuxin Wu's avatar Yuxin Wu

auto detect dir_structure

parent 5e40fa3c
......@@ -202,7 +202,7 @@ def get_data(train_or_test):
isTrain = train_or_test == 'train'
ds = dataset.ILSVRC12(args.data, train_or_test,
shuffle=True if isTrain else False, dir_structure='train')
shuffle=True if isTrain else False)
meta = dataset.ILSVRCMeta()
pp_mean = meta.get_per_pixel_mean()
pp_mean_299 = cv2.resize(pp_mean, (299, 299))
......
......@@ -72,7 +72,7 @@ def get_data(name):
datadir = args.data
augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow(
datadir, name, BATCH_SIZE, augmentors, dir_structure='original')
datadir, name, BATCH_SIZE, augmentors)
def get_config():
......
......@@ -80,7 +80,7 @@ def get_data(name, batch):
isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow(
args.data, name, batch, augmentors, dir_structure='original')
args.data, name, batch, augmentors)
def get_config(model, fake=False):
......
......@@ -82,7 +82,7 @@ def fbresnet_augmentor(isTrain):
def get_imagenet_dataflow(
datadir, name, batch_size,
augmentors, dir_structure='original'):
augmentors):
"""
See explanations in the tutorial:
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
......@@ -97,8 +97,7 @@ def get_imagenet_dataflow(
ds = PrefetchDataZMQ(ds, cpu)
ds = BatchData(ds, batch_size, remainder=False)
else:
ds = dataset.ILSVRC12Files(datadir, name,
shuffle=False, dir_structure=dir_structure)
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
aug = imgaug.AugmentorList(augmentors)
def mapf(dp):
......
......@@ -203,7 +203,7 @@ if __name__ == '__main__':
resnet_param[newname] = v
if args.eval:
ds = ILSVRC12(args.eval, 'val', shuffle=False, dir_structure='train')
ds = ILSVRC12(args.eval, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 128, remainder=True)
ds = PrefetchDataZMQ(ds, 1)
......
......@@ -78,8 +78,7 @@ def get_data(train_or_test):
isTrain = train_or_test == 'train'
datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test,
shuffle=isTrain, dir_structure='original')
ds = dataset.ILSVRC12(datadir, train_or_test, shuffle=isTrain)
augmentors = fbresnet_augmentor(isTrain)
augmentors.append(imgaug.ToUint8())
......
......@@ -105,6 +105,20 @@ class ILSVRCMeta(object):
return arr
def _guess_dir_structure(dir):
subdir = os.listdir(dir)[0]
# find a subdir starting with 'n'
if subdir.startswith('n') and \
os.path.isdir(os.path.join(dir, subdir)):
dir_structure = 'train'
else:
dir_structure = 'original'
logger.info(
"Assuming directory {} has {} structure.".format(
dir, dir_structure))
return dir_structure
class ILSVRC12Files(RNGDataFlow):
"""
Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays.
......@@ -112,7 +126,7 @@ class ILSVRC12Files(RNGDataFlow):
decode it in smarter ways (e.g. in parallel).
"""
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'):
shuffle=None, dir_structure=None):
"""
Same as in :class:`ILSVRC12`.
"""
......@@ -124,6 +138,12 @@ class ILSVRC12Files(RNGDataFlow):
if shuffle is None:
shuffle = name == 'train'
self.shuffle = shuffle
if name == 'train':
dir_structure = 'train'
if dir_structure is None:
dir_structure = _guess_dir_structure(self.full_dir)
meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure)
......@@ -149,7 +169,7 @@ class ILSVRC12(ILSVRC12Files):
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
"""
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'):
shuffle=None, dir_structure=None):
"""
Args:
dir (str): A directory containing a subdir named ``name``, where the
......@@ -162,6 +182,7 @@ class ILSVRC12(ILSVRC12Files):
directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level
directory structure simlar to 'train/'.
By default, it tries to automatically detect the structure.
Examples:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment