Commit ec1bea93 authored by Yuxin Wu's avatar Yuxin Wu

add SE-ResNet

parent 54785d5c
## imagenet-resnet.py
## imagenet-resnet.py, imagenet-resnet-se.py
__Training__ code of pre-activation ResNet on ImageNet. It follows the setup in
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) (except for the weight decay) and gets similar performance (with much fewer lines of code).
__Training__ code of ResNet on ImageNet, with pre-activation and squeeze-and-excitation.
The pre-act ResNet follows the setup in [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) (except for the weight decay)
and gets similar performance (with much fewer lines of code).
Models can be [downloaded here](https://goo.gl/6XjK9V).
| Model | Top 5 Error | Top 1 Error |
......@@ -16,7 +17,7 @@ To train, just run:
```bash
./imagenet-resnet.py --data /path/to/original/ILSVRC --gpu 0,1,2,3 -d 18
```
The speed is 1310 image/s on 4 Tesla M40, if your data is fast enough.
You should be able to see good GPU utilization (around 95%), if your data is fast enough.
See the [tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html) on how to speed up your data.
![imagenet](imagenet-resnet.png)
......@@ -46,9 +47,6 @@ The per-pixel mean used here is slightly different from the original.
Reproduce pre-activation ResNet on CIFAR10.
The train error shown here is a moving average of the error rate of each batch in training.
The validation error here is computed on test set.
![cifar10](cifar10-resnet.png)
Also see a [DenseNet implementation](https://github.com/YixuanLi/densenet-tensorflow) of the paper [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993).
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet-se.py
import sys
import argparse
import numpy as np
import os
import multiprocessing
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
RESNET_CONFIG = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
}
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
def bottleneck_se(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1)
squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.identity)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride)
defs = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
logits = resnet_backbone(image, defs, bottleneck_se)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 0.1, summary=True)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(train_or_test):
isTrain = train_or_test == 'train'
datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test,
shuffle=isTrain, dir_structure='train')
augmentors = fbresnet_augmentor(isTrain)
ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain:
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
def get_config():
assert tf.test.is_gpu_available()
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
dataset_train = get_data('train')
dataset_val = get_data('val')
return TrainConfig(
model=Model(),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
],
steps_per_epoch=5000,
max_epoch=110,
nr_tower=nr_gpu
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101])
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
DEPTH = args.depth
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval:
BATCH_SIZE = 64 # something that can run on one gpu
ds = get_data('val')
eval_on_ILSVRC12(Model(), args.load, ds)
sys.exit()
logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH)))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainerParameterServer(config).train()
......@@ -84,6 +84,10 @@ def get_data(train_or_test):
def get_config(fake=False, data_format='NCHW'):
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
if fake:
dataset_train = dataset_val = FakeData(
[[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
......@@ -105,12 +109,13 @@ def get_config(fake=False, data_format='NCHW'):
],
steps_per_epoch=5000,
max_epoch=110,
nr_tower=nr_gpu
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
......@@ -131,14 +136,9 @@ if __name__ == '__main__':
eval_on_ILSVRC12(Model(), args.load, ds)
sys.exit()
NR_GPU = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH)))
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(NR_GPU, BATCH_SIZE))
config = get_config(fake=args.fake, data_format=args.data_format)
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU
SyncMultiGPUTrainerParameterServer(config).train()
......@@ -78,8 +78,7 @@ def get_data(train_or_test):
datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test,
shuffle=True if isTrain else False,
dir_structure='original')
shuffle=isTrain, dir_structure='original')
augmentors = fbresnet_augmentor(isTrain)
augmentors.append(imgaug.ToUint8())
......@@ -91,6 +90,9 @@ def get_data(train_or_test):
def get_config():
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
dataset_train = get_data('train')
dataset_val = get_data('val')
......@@ -107,6 +109,7 @@ def get_config():
],
steps_per_epoch=5000,
max_epoch=110,
nr_tower=nr_gpu
)
......@@ -162,12 +165,8 @@ if __name__ == '__main__':
viz_cam(args.load, args.data)
sys.exit()
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.auto_set_dir()
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.nr_tower = nr_gpu
SyncMultiGPUTrainer(config).train()
......@@ -5,6 +5,7 @@
import six
import os
import pprint
import tensorflow as tf
from collections import defaultdict
import numpy as np
......@@ -127,7 +128,8 @@ def dump_session_params(path):
for v in var:
result[v.name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
keys = sorted(list(result.keys()))
logger.info(pprint.pformat(keys))
np.save(path, result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment