Commit f3c50d39 authored by Yuxin Wu

use imagenet_utils for inception-bn

parent 33304beb
../ResNet/imagenet_utils.py
\ No newline at end of file
...@@ -14,7 +14,9 @@ from tensorpack import * ...@@ -14,7 +14,9 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow
TOTAL_BATCH_SIZE = 64 * 6 TOTAL_BATCH_SIZE = 64 * 6
NR_GPU = 6 NR_GPU = 6
...@@ -25,8 +27,7 @@ INPUT_SHAPE = 224 ...@@ -25,8 +27,7 @@ INPUT_SHAPE = 224
Inception-BN model on ILSVRC12. Inception-BN model on ILSVRC12.
See "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", arxiv:1502.03167 See "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", arxiv:1502.03167
This config reaches 71% single-crop validation accuracy after 150k steps with 6 TitanX. This config reaches 73% single-crop validation accuracy after 300k steps with 6 GPUs.
Learning rate may need a different schedule for different number of GPUs (because batch size will be different).
""" """
...@@ -129,30 +130,13 @@ class Model(ModelDesc): ...@@ -129,30 +130,13 @@ class Model(ModelDesc):
def get_data(train_or_test): def get_data(train_or_test):
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
ds = dataset.ILSVRC12(args.data, train_or_test, shuffle=True if isTrain else False) augs = fbresnet_augmentor(isTrain)
meta = dataset.ILSVRCMeta() meta = dataset.ILSVRCMeta()
pp_mean = meta.get_per_pixel_mean() pp_mean = meta.get_per_pixel_mean()
augs.append(imgaug.MapImage(lambda x: x - pp_mean[16:-16, 16:-16]))
if isTrain: ds = get_imagenet_dataflow(args.data, train_or_test, BATCH_SIZE, augs)
# TODO use the augmentor in GoogleNet
augmentors = [
imgaug.Resize((256, 256)),
imgaug.Brightness(30, False),
imgaug.Contrast((0.8, 1.2), True),
imgaug.MapImage(lambda x: x - pp_mean),
imgaug.RandomCrop((224, 224)),
imgaug.Flip(horiz=True),
]
else:
augmentors = [
imgaug.Resize((256, 256)),
imgaug.MapImage(lambda x: x - pp_mean),
imgaug.CenterCrop((224, 224)),
]
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, 6)
return ds return ds
...@@ -192,7 +176,6 @@ if __name__ == '__main__': ...@@ -192,7 +176,6 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: nr_tower = get_nr_gpu()
nr_tower = len(args.gpu.split(','))
assert nr_tower == NR_GPU assert nr_tower == NR_GPU
launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU)) launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment