Commit 578439f1 authored by Yuxin Wu's avatar Yuxin Wu

update resnet-se

parent a9552311
...@@ -15,7 +15,7 @@ Models can be [downloaded here](https://goo.gl/6XjK9V). ...@@ -15,7 +15,7 @@ Models can be [downloaded here](https://goo.gl/6XjK9V).
| ResNet18 | 10.50% | 29.66% | | ResNet18 | 10.50% | 29.66% |
| ResNet34 | 8.56% | 26.17% | | ResNet34 | 8.56% | 26.17% |
| ResNet50 | 6.85% | 23.61% | | ResNet50 | 6.85% | 23.61% |
| ResNet50-SE | TRAINING | TRAINING | | ResNet50-SE | 6.24% | 22.64% |
| ResNet101 | 6.04% | 21.95% | | ResNet101 | 6.04% | 21.95% |
To train, just run: To train, just run:
......
...@@ -6,7 +6,6 @@ import sys ...@@ -6,7 +6,6 @@ import sys
import argparse import argparse
import numpy as np import numpy as np
import os import os
import multiprocessing
import tensorflow as tf import tensorflow as tf
...@@ -18,7 +17,8 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -18,7 +17,8 @@ from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone, fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
preresnet_group, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error, resnet_group, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_bn,
get_imagenet_dataflow) get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
...@@ -45,17 +45,17 @@ class Model(ModelDesc): ...@@ -45,17 +45,17 @@ class Model(ModelDesc):
l, shortcut = apply_preactivation(l, preact) l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU) l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1) l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
squeeze = GlobalAvgPooling('gap', l) squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu) squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid) squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1]) l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride) return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
defs = RESNET_CONFIG[DEPTH] defs = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'): with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
logits = resnet_backbone(image, defs, preresnet_group, bottleneck_se) logits = resnet_backbone(image, defs, resnet_group, bottleneck_se)
loss = compute_loss_and_error(logits, label) loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss') wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
...@@ -67,38 +67,44 @@ class Model(ModelDesc): ...@@ -67,38 +67,44 @@ class Model(ModelDesc):
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(name): def get_data(name, batch):
isTrain = name == 'train' isTrain = name == 'train'
datadir = args.data
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
datadir = args.data
return get_imagenet_dataflow( return get_imagenet_dataflow(
datadir, name, BATCH_SIZE, augmentors) datadir, name, batch, augmentors)
def get_config(): def get_config():
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
global BATCH_SIZE batch = TOTAL_BATCH_SIZE // nr_gpu
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, batch))
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
dataset_train = get_data('train', batch)
dataset_train = get_data('train') dataset_val = get_data('val', batch)
dataset_val = get_data('val')
callbacks = [
ModelSaver(),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
HumanHyperParamSetter('learning_rate'),
]
infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1:
callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
else:
callbacks.append(DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower))))
return TrainConfig( return TrainConfig(
model=Model(), model=Model(),
dataflow=dataset_train, dataflow=dataset_train,
callbacks=[ callbacks=callbacks,
ModelSaver(),
InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
],
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=110, max_epoch=110,
nr_tower=nr_gpu nr_tower=nr_tower
) )
...@@ -117,15 +123,14 @@ if __name__ == '__main__': ...@@ -117,15 +123,14 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval: if args.eval:
BATCH_SIZE = 128 # something that can run on one gpu ds = get_data('val', 128)
ds = get_data('val')
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds) eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
sys.exit() sys.exit()
logger.set_logger_dir( logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH))) os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH)))
config = get_config() config = get_config(Model())
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainerParameterServer(config).train() SyncMultiGPUTrainerParameterServer(config).train()
...@@ -43,7 +43,8 @@ class Model(ModelDesc): ...@@ -43,7 +43,8 @@ class Model(ModelDesc):
18: ([2, 2, 2, 2], basicblock), 18: ([2, 2, 2, 2], basicblock),
34: ([3, 4, 6, 3], basicblock), 34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck), 50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck) 101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}[depth] }[depth]
def _get_inputs(self): def _get_inputs(self):
...@@ -131,7 +132,7 @@ if __name__ == '__main__': ...@@ -131,7 +132,7 @@ if __name__ == '__main__':
parser.add_argument('--data_format', help='specify NCHW or NHWC', parser.add_argument('--data_format', help='specify NCHW or NHWC',
type=str, default='NCHW') type=str, default='NCHW')
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101]) type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument('--preact', action='store_true', help='Use pre-activation resnet') parser.add_argument('--preact', action='store_true', help='Use pre-activation resnet')
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment