Commit 15260844 authored by Yuxin Wu

merge resnet & resnet-se

parent de578d69
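Editor's note: this commit folds the standalone imagenet-resnet-se.py (reproduced in full below) into the generic ResNet example. The squeeze-and-excitation block moves into the shared utilities as se_resnet_bottleneck, and the main training script gains a --mode flag ('resnet', 'preact', 'se') that selects the block function. The following is a condensed, illustrative restatement of the new Model.__init__ dispatch shown in the diff further down; the helper name choose_blocks is made up for this sketch.

from imagenet_resnet_utils import (
    preresnet_basicblock, preresnet_bottleneck, resnet_basicblock,
    resnet_bottleneck, se_resnet_bottleneck)

def choose_blocks(mode, depth):
    # mode: 'resnet', 'preact' or 'se'; depth: one of 18/34/50/101/152
    basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
    bottleneck = {
        'resnet': resnet_bottleneck,
        'preact': preresnet_bottleneck,
        'se': se_resnet_bottleneck}[mode]   # SE only has a bottleneck variant, so it needs depth >= 50
    return {
        18: ([2, 2, 2, 2], basicblock),
        34: ([3, 4, 6, 3], basicblock),
        50: ([3, 4, 6, 3], bottleneck),
        101: ([3, 4, 23, 3], bottleneck),
        152: ([3, 8, 36, 3], bottleneck)}[depth]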
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet-se.py
import sys
import argparse
import numpy as np
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
resnet_group, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_bn,
get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
RESNET_CONFIG = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
}
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
def bottleneck_se(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
defs = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
logits = resnet_backbone(image, defs, resnet_group, bottleneck_se)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 0.1, summary=True)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(name, batch):
isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain)
datadir = args.data
return get_imagenet_dataflow(
datadir, name, batch, augmentors)
def get_config():
assert tf.test.is_gpu_available()
nr_tower = get_nr_gpu()
batch = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_tower, batch))
dataset_train = get_data('train', batch)
dataset_val = get_data('val', batch)
callbacks = [
ModelSaver(),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
HumanHyperParamSetter('learning_rate'),
]
infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1:
callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
else:
callbacks.append(DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower))))
return TrainConfig(
model=Model(),
dataflow=dataset_train,
callbacks=callbacks,
steps_per_epoch=5000,
max_epoch=110,
nr_tower=nr_tower
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=50, choices=[50, 101])
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
DEPTH = args.depth
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval:
ds = get_data('val', 128)
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
sys.exit()
logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH)))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainerParameterServer(config).train()
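For reference, here is a minimal NumPy sketch (not part of the commit) of the channel recalibration that bottleneck_se above, and se_resnet_bottleneck below, apply to an NCHW feature map. The toy weight shapes assume the reduction used in the code, where fc1 maps the ch_out * 4 pooled channels down to ch_out // 4, i.e. a 16x squeeze.

import numpy as np

def se_recalibrate(x, w1, w2):
    # x: (N, C, H, W) feature map; w1: (C, C // 16), w2: (C // 16, C) toy FC weights
    squeeze = x.mean(axis=(2, 3))                   # GlobalAvgPooling -> (N, C)
    hidden = np.maximum(squeeze.dot(w1), 0)         # fc1 + ReLU
    gate = 1.0 / (1.0 + np.exp(-hidden.dot(w2)))    # fc2 + sigmoid, one gate per channel
    return x * gate[:, :, None, None]               # per-channel rescaling, same as the tf.reshape multiply above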
@@ -9,36 +9,34 @@ import os
import tensorflow as tf
from tensorpack import InputDesc, ModelDesc, logger, QueueInput
from tensorpack import logger, QueueInput
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer
from tensorpack.dataflow import imgaug, FakeData
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import (
fbresnet_augmentor, get_imagenet_dataflow,
preresnet_group, preresnet_basicblock, preresnet_bottleneck,
resnet_group, resnet_basicblock, resnet_bottleneck,
resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error)
resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
resnet_backbone, ImageNetModel,
eval_on_ILSVRC12)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
class Model(ModelDesc):
def __init__(self, depth, data_format='NCHW', preact=False):
if data_format == 'NCHW':
assert tf.test.is_gpu_available()
self.data_format = data_format
self.preact = preact
class Model(ImageNetModel):
def __init__(self, depth, data_format='NCHW', mode='resnet'):
super(Model, self).__init__(data_format)
basicblock = preresnet_basicblock if preact else resnet_basicblock
bottleneck = preresnet_bottleneck if preact else resnet_bottleneck
self.mode = mode
basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
bottleneck = {
'resnet': resnet_bottleneck,
'preact': preresnet_bottleneck,
'se': se_resnet_bottleneck}[mode]
self.num_blocks, self.block_func = {
18: ([2, 2, 2, 2], basicblock),
34: ([3, 4, 6, 3], basicblock),
@@ -47,34 +45,11 @@ class Model(ModelDesc):
152: ([3, 8, 36, 3], bottleneck)
}[depth]
def _get_inputs(self):
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
def get_logits(self, image):
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = resnet_backbone(
return resnet_backbone(
image, self.num_blocks,
preresnet_group if self.preact else resnet_group, self.block_func)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self):
lr = symbf.get_scalar_var('learning_rate', 0.1, summary=True)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
preresnet_group if self.mode == 'preact' else resnet_group, self.block_func)
def get_data(name, batch):
@@ -134,13 +109,17 @@ if __name__ == '__main__':
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true')
parser.add_argument('--preact', action='store_true', help='Use pre-activation resnet')
parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
help='variants of resnet to use', default='resnet')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
model = Model(args.depth, args.data_format, args.preact)
if args.mode == 'se':
assert args.depth >= 50
model = Model(args.depth, args.data_format, args.mode)
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
@@ -5,11 +5,12 @@
import numpy as np
import cv2
import multiprocessing
from abc import abstractmethod
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import imgaug, dataset
from tensorpack import imgaug, dataset, ModelDesc, InputDesc
from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ,
BatchData, ThreadedMapData)
@@ -17,8 +18,8 @@ from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.models import (
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU,
LinearWrap)
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected,
LinearWrap, regularize_cost)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
@@ -120,11 +121,8 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def apply_preactivation(l, preact):
"""
'no_preact' for the first resblock in each group only, because the input is activated already.
'bnrelu' for all the non-first blocks, where identity mapping is preserved on shortcut path.
"""
if preact == 'bnrelu':
# this is used only for preact-resnet
shortcut = l # preserve identity mapping
l = BNReLU('preact', l)
else:
@@ -186,13 +184,25 @@ def resnet_bottleneck(l, ch_out, stride, preact):
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def se_resnet_bottleneck(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def resnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name):
for i in range(0, count):
with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features,
stride if i == 0 else 1,
'no_preact')
stride if i == 0 else 1, 'no_preact')
# the end of each block needs an activation
l = tf.nn.relu(l)
return l
@@ -262,3 +272,46 @@ def compute_loss_and_error(logits, label):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
return loss
class ImageNetModel(ModelDesc):
def __init__(self, data_format='NCHW', image_dtype=tf.uint8):
if data_format == 'NCHW':
assert tf.test.is_gpu_available()
self.data_format = data_format
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
self.image_dtype = image_dtype
def _get_inputs(self):
return [InputDesc(self.image_dtype, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
logits = self.get_logits(image)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
@abstractmethod
def get_logits(self, image):
"""
Args:
image: 4D tensor of 224x224 in ``self.data_format``
Returns:
Bx1000 logits
"""
def _get_optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
tf.summary.scalar('learning_rate-summary', lr)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
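With the refactor above, a network only has to subclass ImageNetModel and implement get_logits(); inputs, preprocessing, loss, weight decay and the optimizer all come from the base class. Below is a minimal sketch under the pre-1.0 tensorpack API used in this commit; TinyNet and its layers are made up for illustration and are not part of the repository.

import tensorflow as tf
from tensorpack.models import Conv2D, GlobalAvgPooling, FullyConnected
from tensorpack.tfutils import argscope
from imagenet_resnet_utils import ImageNetModel

class TinyNet(ImageNetModel):
    def get_logits(self, image):
        # image is 224x224 in self.data_format; return Bx1000 logits
        with argscope([Conv2D, GlobalAvgPooling], data_format=self.data_format):
            l = Conv2D('conv0', image, 64, 7, stride=2, nl=tf.nn.relu)
            l = GlobalAvgPooling('gap', l)
            return FullyConnected('linear', l, 1000, nl=tf.identity)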