Commit 15260844 authored by Yuxin Wu's avatar Yuxin Wu

merge resnet & resnet-se

parent de578d69
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet-se.py
import sys
import argparse
import numpy as np
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
resnet_group, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_bn,
get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
RESNET_CONFIG = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
}
class Model(ModelDesc):
    """Squeeze-and-Excitation ResNet (bottleneck variant) for ImageNet.

    Depth is selected through the module-level ``DEPTH`` global (50 or 101),
    which indexes ``RESNET_CONFIG`` for the per-group block counts.
    """

    def _get_inputs(self):
        # 224x224 float32 images (BGR, per image_preprocess below) and int32 labels.
        return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        image = image_preprocess(image, bgr=True)
        # NHWC -> NCHW: all layers below run in channels-first layout.
        image = tf.transpose(image, [0, 3, 1, 2])

        def bottleneck_se(l, ch_out, stride, preact):
            # 1x1 -> 3x3 (strided) -> 1x1 bottleneck with an SE gating branch.
            l, shortcut = apply_preactivation(l, preact)
            l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
            l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
            # NOTE(review): get_bn(zero_init=True) presumably yields a BatchNorm
            # with zero-initialized scale on the residual branch — confirm in
            # imagenet_resnet_utils.
            l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))

            # Squeeze: global average pool; Excitation: two FCs producing
            # per-channel gates in (0, 1). Channel reduction is
            # (ch_out * 4) -> (ch_out // 4), i.e. a ratio of 16.
            squeeze = GlobalAvgPooling('gap', l)
            squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
            squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
            # Broadcast the gates over the spatial dims (NCHW layout).
            l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])

            return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))

        # Per-group block counts for the chosen depth, e.g. [3, 4, 6, 3] for 50.
        defs = RESNET_CONFIG[DEPTH]

        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
            logits = resnet_backbone(image, defs, resnet_group, bottleneck_se)

        loss = compute_loss_and_error(logits, label)
        # L2 weight decay on every weight matrix/kernel named 'W'.
        wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
        add_moving_summary(loss, wd_loss)
        self.cost = tf.add_n([loss, wd_loss], name='cost')

    def _get_optimizer(self):
        # SGD with Nesterov momentum; the learning_rate variable is adjusted
        # during training by the ScheduledHyperParamSetter callback.
        lr = get_scalar_var('learning_rate', 0.1, summary=True)
        return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(name, batch):
    """Build the ImageNet dataflow for split *name* ('train' or 'val')
    with the given per-GPU batch size. Reads the dataset dir from the
    module-level ``args``."""
    is_training = (name == 'train')
    return get_imagenet_dataflow(
        args.data, name, batch,
        fbresnet_augmentor(is_training))
def get_config(model=None):
    """Build the TrainConfig for multi-GPU training.

    Args:
        model: optional ``Model`` instance to train; a fresh ``Model()`` is
            created when omitted. (Added with a default so the existing
            ``get_config(Model())`` call site keeps working.)

    Returns:
        tensorpack ``TrainConfig`` with data, callbacks and LR schedule set up.
    """
    assert tf.test.is_gpu_available()
    nr_gpu = get_nr_gpu()
    batch = TOTAL_BATCH_SIZE // nr_gpu
    logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, batch))

    dataset_train = get_data('train', batch)
    dataset_val = get_data('val', batch)

    callbacks = [
        ModelSaver(),
        # Step-wise LR decay schedule keyed on epoch number.
        ScheduledHyperParamSetter('learning_rate',
                                  [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
        # Allows manual LR override from a hyper-param file during training.
        HumanHyperParamSetter('learning_rate'),
    ]
    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    # BUG FIX: the original referenced an undefined name `nr_tower`, which
    # raised NameError at runtime; the GPU count computed above is `nr_gpu`.
    if nr_gpu == 1:
        callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
    else:
        # Run validation data-parallel across all GPUs.
        callbacks.append(DataParallelInferenceRunner(
            dataset_val, infs, list(range(nr_gpu))))

    return TrainConfig(
        model=model if model is not None else Model(),
        dataflow=dataset_train,
        callbacks=callbacks,
        steps_per_epoch=5000,
        max_epoch=110,
        nr_tower=nr_gpu
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--load', help='load model')
    parser.add_argument('-d', '--depth', help='resnet depth',
                        type=int, default=50, choices=[50, 101])
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    # DEPTH is a module-level global read by Model._build_graph.
    DEPTH = args.depth
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.eval:
        # Evaluation-only path: score the validation set and exit.
        ds = get_data('val', 128)
        eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
        sys.exit()

    logger.set_logger_dir(
        os.path.join('train_log', 'imagenet-resnet-se-d' + str(DEPTH)))
    # BUG FIX: get_config() takes no positional model argument (it constructs
    # Model() itself); the original `get_config(Model())` raised TypeError.
    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    SyncMultiGPUTrainerParameterServer(config).train()
...@@ -9,36 +9,34 @@ import os ...@@ -9,36 +9,34 @@ import os
import tensorflow as tf import tensorflow as tf
from tensorpack import InputDesc, ModelDesc, logger, QueueInput from tensorpack import logger, QueueInput
from tensorpack.models import * from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer
from tensorpack.dataflow import imgaug, FakeData from tensorpack.dataflow import imgaug, FakeData
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import argscope, get_model_loader from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, get_imagenet_dataflow, fbresnet_augmentor, get_imagenet_dataflow,
preresnet_group, preresnet_basicblock, preresnet_bottleneck, preresnet_group, preresnet_basicblock, preresnet_bottleneck,
resnet_group, resnet_basicblock, resnet_bottleneck, resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
resnet_backbone, resnet_backbone, ImageNetModel,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error) eval_on_ILSVRC12)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
class Model(ModelDesc): class Model(ImageNetModel):
def __init__(self, depth, data_format='NCHW', preact=False): def __init__(self, depth, data_format='NCHW', mode='resnet'):
if data_format == 'NCHW': super(Model, self).__init__(data_format)
assert tf.test.is_gpu_available()
self.data_format = data_format
self.preact = preact
basicblock = preresnet_basicblock if preact else resnet_basicblock self.mode = mode
bottleneck = preresnet_bottleneck if preact else resnet_bottleneck basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
bottleneck = {
'resnet': resnet_bottleneck,
'preact': preresnet_bottleneck,
'se': se_resnet_bottleneck}[mode]
self.num_blocks, self.block_func = { self.num_blocks, self.block_func = {
18: ([2, 2, 2, 2], basicblock), 18: ([2, 2, 2, 2], basicblock),
34: ([3, 4, 6, 3], basicblock), 34: ([3, 4, 6, 3], basicblock),
...@@ -47,34 +45,11 @@ class Model(ModelDesc): ...@@ -47,34 +45,11 @@ class Model(ModelDesc):
152: ([3, 8, 36, 3], bottleneck) 152: ([3, 8, 36, 3], bottleneck)
}[depth] }[depth]
def _get_inputs(self): def get_logits(self, image):
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format): with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = resnet_backbone( return resnet_backbone(
image, self.num_blocks, image, self.num_blocks,
preresnet_group if self.preact else resnet_group, self.block_func) preresnet_group if self.mode == 'preact' else resnet_group, self.block_func)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self):
lr = symbf.get_scalar_var('learning_rate', 0.1, summary=True)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(name, batch): def get_data(name, batch):
...@@ -134,13 +109,17 @@ if __name__ == '__main__': ...@@ -134,13 +109,17 @@ if __name__ == '__main__':
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152]) type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument('--preact', action='store_true', help='Use pre-activation resnet') parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
help='variants of resnet to use', default='resnet')
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
model = Model(args.depth, args.data_format, args.preact) if args.mode == 'se':
assert args.depth >= 50
model = Model(args.depth, args.data_format, args.mode)
if args.eval: if args.eval:
batch = 128 # something that can run on one gpu batch = 128 # something that can run on one gpu
ds = get_data('val', batch) ds = get_data('val', batch)
......
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
import numpy as np import numpy as np
import cv2 import cv2
import multiprocessing import multiprocessing
from abc import abstractmethod
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import imgaug, dataset from tensorpack import imgaug, dataset, ModelDesc, InputDesc
from tensorpack.dataflow import ( from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ, AugmentImageComponent, PrefetchDataZMQ,
BatchData, ThreadedMapData) BatchData, ThreadedMapData)
...@@ -17,8 +18,8 @@ from tensorpack.utils.stats import RatioCounter ...@@ -17,8 +18,8 @@ from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.argscope import argscope, get_arg_scope from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.models import ( from tensorpack.models import (
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected,
LinearWrap) LinearWrap, regularize_cost)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
...@@ -120,11 +121,8 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity): ...@@ -120,11 +121,8 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def apply_preactivation(l, preact): def apply_preactivation(l, preact):
"""
'no_preact' for the first resblock in each group only, because the input is activated already.
'bnrelu' for all the non-first blocks, where identity mapping is preserved on shortcut path.
"""
if preact == 'bnrelu': if preact == 'bnrelu':
# this is used only for preact-resnet
shortcut = l # preserve identity mapping shortcut = l # preserve identity mapping
l = BNReLU('preact', l) l = BNReLU('preact', l)
else: else:
...@@ -186,13 +184,25 @@ def resnet_bottleneck(l, ch_out, stride, preact): ...@@ -186,13 +184,25 @@ def resnet_bottleneck(l, ch_out, stride, preact):
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False)) return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def se_resnet_bottleneck(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def resnet_group(l, name, block_func, features, count, stride): def resnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name): with tf.variable_scope(name):
for i in range(0, count): for i in range(0, count):
with tf.variable_scope('block{}'.format(i)): with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features, l = block_func(l, features,
stride if i == 0 else 1, stride if i == 0 else 1, 'no_preact')
'no_preact')
# end of each block need an activation # end of each block need an activation
l = tf.nn.relu(l) l = tf.nn.relu(l)
return l return l
...@@ -262,3 +272,46 @@ def compute_loss_and_error(logits, label): ...@@ -262,3 +272,46 @@ def compute_loss_and_error(logits, label):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
return loss return loss
class ImageNetModel(ModelDesc):
def __init__(self, data_format='NCHW', image_dtype=tf.uint8):
if data_format == 'NCHW':
assert tf.test.is_gpu_available()
self.data_format = data_format
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
self.image_dtype = image_dtype
def _get_inputs(self):
return [InputDesc(self.image_dtype, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image_preprocess(image, bgr=True)
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
logits = self.get_logits(image)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
@abstractmethod
def get_logits(self, image):
"""
Args:
image: 4D tensor of 224x224 in ``self.data_format``
Returns:
Bx1000 logits
"""
def _get_optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
tf.summary.scalar('learning_rate-summary', lr)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment