Commit ec1bea93 authored by Yuxin Wu

add SE-ResNet

parent 54785d5c
## imagenet-resnet.py, imagenet-resnet-se.py

__Training__ code of ResNet on ImageNet, with pre-activation and squeeze-and-excitation.
The pre-act ResNet follows the setup in [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) (except for the weight decay)
and gets similar performance (with far fewer lines of code).
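
The squeeze-and-excitation (SE) block recalibrates channels: the feature map is globally average-pooled to a vector, two fully-connected layers turn that vector into one sigmoid gate per channel, and every channel is rescaled by its gate. A minimal sketch of the idea in plain TensorFlow (NCHW layout; the function name and reduction ratio are illustrative, the actual block is `bottleneck_se` in `imagenet-resnet-se.py` below):

```python
import tensorflow as tf

def se_recalibrate(l, channels, ratio=4):
    # squeeze: global average pooling over the spatial dims of [N, C, H, W]
    squeeze = tf.reduce_mean(l, axis=[2, 3])                    # [N, C]
    # excitation: bottleneck FC, then one sigmoid weight per channel
    w = tf.layers.dense(squeeze, channels // ratio, activation=tf.nn.relu)
    w = tf.layers.dense(w, channels, activation=tf.nn.sigmoid)  # [N, C]
    # rescale each channel of the input by its learned gate
    return l * tf.reshape(w, [-1, channels, 1, 1])
```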
Models can be [downloaded here](https://goo.gl/6XjK9V).

| Model | Top 5 Error | Top 1 Error |
@@ -16,7 +17,7 @@ To train, just run:

```bash
./imagenet-resnet.py --data /path/to/original/ILSVRC --gpu 0,1,2,3 -d 18
```
You should be able to see good GPU utilization (around 95%) if your data is fast enough.
See the [tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html) on how to speed up your data.
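
The pattern, mirrored by `get_data()` in the scripts below, is to parallelize decoding and augmentation across processes before batching. A minimal sketch (dataset path hypothetical):

```python
from tensorpack import dataset, AugmentImageComponent, PrefetchDataZMQ, BatchData
from imagenet_resnet_utils import fbresnet_augmentor

ds = dataset.ILSVRC12('/path/to/ILSVRC', 'train', shuffle=True)
ds = AugmentImageComponent(ds, fbresnet_augmentor(True), copy=False)
ds = PrefetchDataZMQ(ds, 25)  # 25 worker processes decode and augment in parallel
ds = BatchData(ds, 64)
```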
![imagenet](imagenet-resnet.png)
@@ -46,9 +47,6 @@ The per-pixel mean used here is slightly different from the original.

Reproduce pre-activation ResNet on CIFAR10.
The train error shown here is a moving average of the error rate of each batch during training.
The validation error here is computed on the test set.

![cifar10](cifar10-resnet.png)

Also see a [DenseNet implementation](https://github.com/YixuanLi/densenet-tensorflow) of the paper [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993).
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet-se.py
import sys
import argparse
import numpy as np
import os
import multiprocessing
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorflow.contrib.layers import l2_regularizer  # used by regularize_cost below
from imagenet_resnet_utils import (
    fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
    eval_on_ILSVRC12, image_preprocess, compute_loss_and_error)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
# number of residual blocks in each of the 4 stages, per supported depth
RESNET_CONFIG = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
}
class Model(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        image = image_preprocess(image, bgr=True)
        image = tf.transpose(image, [0, 3, 1, 2])  # NHWC -> NCHW

        def bottleneck_se(l, ch_out, stride, preact):
            l, shortcut = apply_preactivation(l, preact)
            l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
            l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
            l = Conv2D('conv3', l, ch_out * 4, 1)

            # squeeze-and-excitation: pool the map globally, predict one
            # sigmoid weight per channel, then rescale the feature map
            squeeze = GlobalAvgPooling('gap', l)
            squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.identity)
            squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
            l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])

            return l + resnet_shortcut(shortcut, ch_out * 4, stride)

        defs = RESNET_CONFIG[DEPTH]
        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
            logits = resnet_backbone(image, defs, bottleneck_se)

        loss = compute_loss_and_error(logits, label)
        wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
        add_moving_summary(loss, wd_loss)
        self.cost = tf.add_n([loss, wd_loss], name='cost')

    def _get_optimizer(self):
        lr = get_scalar_var('learning_rate', 0.1, summary=True)
        return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)


def get_data(train_or_test):
    isTrain = train_or_test == 'train'
    datadir = args.data
    ds = dataset.ILSVRC12(datadir, train_or_test,
                          shuffle=isTrain, dir_structure='train')
    augmentors = fbresnet_augmentor(isTrain)
    ds = AugmentImageComponent(ds, augmentors, copy=False)
    if isTrain:
        ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
    return ds


def get_config():
    assert tf.test.is_gpu_available()
    global BATCH_SIZE  # get_data() above reads this name at module scope
    nr_gpu = get_nr_gpu()
    BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
    logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))

    dataset_train = get_data('train')
    dataset_val = get_data('val')

    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,
        callbacks=[
            ModelSaver(),
            InferenceRunner(dataset_val, [
                ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]),
            ScheduledHyperParamSetter('learning_rate',
                                      [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
        ],
        steps_per_epoch=5000,
        max_epoch=110,
        nr_tower=nr_gpu
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--load', help='load model')
    parser.add_argument('-d', '--depth', help='resnet depth',
                        type=int, default=50, choices=[50, 101])  # only depths with an entry in RESNET_CONFIG
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()
    DEPTH = args.depth

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.eval:
        BATCH_SIZE = 64    # something that can run on one gpu
        ds = get_data('val')
        eval_on_ILSVRC12(Model(), args.load, ds)
        sys.exit()

    logger.set_logger_dir(
        os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH)))
    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    SyncMultiGPUTrainerParameterServer(config).train()
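
Assuming the new script is launched the same way as `imagenet-resnet.py` (all flags are defined above; only the depths present in `RESNET_CONFIG` apply), training the SE variant would look like:

```bash
./imagenet-resnet-se.py --data /path/to/original/ILSVRC --gpu 0,1,2,3 -d 50
```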
@@ -84,6 +84,10 @@ def get_data(train_or_test):

def get_config(fake=False, data_format='NCHW'):
    global BATCH_SIZE  # get_data() reads this name at module scope
    nr_gpu = get_nr_gpu()
    BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
    logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
    if fake:
        dataset_train = dataset_val = FakeData(
            [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')

@@ -105,12 +109,13 @@ def get_config(fake=False, data_format='NCHW'):
        ],
        steps_per_epoch=5000,
        max_epoch=110,
        nr_tower=nr_gpu
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')

@@ -131,14 +136,9 @@ if __name__ == '__main__':
        eval_on_ILSVRC12(Model(), args.load, ds)
        sys.exit()

    logger.set_logger_dir(
        os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH)))
    config = get_config(fake=args.fake, data_format=args.data_format)
    if args.load:
        config.session_init = SaverRestore(args.load)
    SyncMultiGPUTrainerParameterServer(config).train()
@@ -78,8 +78,7 @@ def get_data(train_or_test):
    datadir = args.data
    ds = dataset.ILSVRC12(datadir, train_or_test,
                          shuffle=isTrain, dir_structure='original')
    augmentors = fbresnet_augmentor(isTrain)
    augmentors.append(imgaug.ToUint8())

@@ -91,6 +90,9 @@ def get_data(train_or_test):

def get_config():
    global BATCH_SIZE  # get_data() reads this name at module scope
    nr_gpu = get_nr_gpu()
    BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
    dataset_train = get_data('train')
    dataset_val = get_data('val')

@@ -107,6 +109,7 @@ def get_config():
        ],
        steps_per_epoch=5000,
        max_epoch=110,
        nr_tower=nr_gpu
    )

@@ -162,12 +165,8 @@ if __name__ == '__main__':
        viz_cam(args.load, args.data)
        sys.exit()

    logger.auto_set_dir()
    config = get_config()
    if args.load:
        config.session_init = get_model_loader(args.load)
    SyncMultiGPUTrainer(config).train()
@@ -5,6 +5,7 @@
import six
import os
import pprint
import tensorflow as tf
from collections import defaultdict
import numpy as np

@@ -127,7 +128,8 @@ def dump_session_params(path):
    for v in var:
        result[v.name] = v.eval()
    logger.info("Variables to save to {}:".format(path))
    keys = sorted(list(result.keys()))
    logger.info(pprint.pformat(keys))
    np.save(path, result)
...
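
Since `dump_session_params` stores the whole dict with `np.save`, the sorted key listing can also be reproduced offline. A minimal sketch (file name hypothetical):

```python
import numpy as np

params = np.load('model.npy').item()  # 0-d object array holding {variable name: ndarray}
for name in sorted(params.keys()):
    print(name, params[name].shape)
```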