Commit e086f05a authored by yselivonchyk, committed by Yuxin Wu

ResNet-mixup example: align implementation with the one referenced by the original paper (#571)

* Align the implementation with the reference implementation used by the paper.

The pre-activation ResNet-18 from https://github.com/kuangliu/pytorch-cifar uses a pre-activation block with 2 consecutive convolution layers per block; the existing implementation was using 3.

Weight decay was set incorrectly.

Architecture aligned with the main repository's approach: functions are defined for bottleneck and regular PreActResNet blocks.

Support for multiple depths added.

* Pre-activation block: no BNReLU should appear outside of the residual branch

* Code migration cleanup: blocks rearranged, variable names aligned

* Follow the reference implementation correctly: BNReLU is used in the identity branch only before a convolutional layer (see the annotated block sketch below)

* Updated model accuracies after a single run

* Documentation update

* Move closer to the mixup experiment settings

* fix lint
parent 26e609f8
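For orientation before the diff: the block described in the commit message (two 3x3 convolutions, BNReLU only inside the residual branch, and a 1x1 projection on the shortcut only when shapes change) corresponds to the new `preactivation_block` in the Python diff below. The following restates that block as a standalone annotated sketch using the tensorpack 0.x API from the script; it is illustrative only and not itself part of the commit.

```python
import tensorflow as tf
from tensorpack import Conv2D, BNReLU

def preactivation_block(input, num_filters, stride=1):
    # Channels-first (NCHW) layout, as used by the script.
    num_filters_in = input.get_shape().as_list()[1]

    # Residual branch: BNReLU precedes each convolution, and the block has
    # exactly two 3x3 convolutions (the old code used three).
    net = BNReLU(input)
    residual = Conv2D('conv1', net, num_filters, kernel_shape=3, stride=stride,
                      use_bias=False, nl=BNReLU)
    residual = Conv2D('conv2', residual, num_filters, kernel_shape=3, stride=1,
                      use_bias=False, nl=tf.identity)

    # Identity branch: a 1x1 projection is applied only when the spatial size
    # or channel count changes, and it reads the already pre-activated tensor,
    # so no BNReLU appears outside the residual branch on its own.
    shortcut = input
    if stride != 1 or num_filters_in != num_filters:
        shortcut = Conv2D('shortcut', net, num_filters, kernel_shape=1,
                          stride=stride, use_bias=False, nl=tf.identity)
    return shortcut + residual
```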
@@ -67,9 +67,8 @@ Reproduce the mixup pre-act ResNet-18 CIFAR10 experiment, in the paper:
 * [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412).
-Please note that this preact18 architecture is
-[different](https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py)
-from `cifar10-resnet18.py`.
+This implementation follows the exact settings from the [author's code](https://github.com/hongyi-zhang/mixup).
+Note that the architecture is different from the official preact-ResNet18.

 Usage:
 ```bash
@@ -77,7 +76,5 @@ Usage:
 ./cifar10-preact18-mixup.py --mixup  # with mixup
 ```
-Validation error with the original LR schedule (100-150-200): __5.0%__ without mixup, __3.8%__ with mixup.
-This matches the number in the paper.
-With 2x LR schedule: 4.7% without mixup, and 3.2% with mixup.
+Results of the reference code can be reproduced.
+In one run it gives: 5.48% validation error without mixup; __4.17%__ with mixup (alpha=1).
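For reference, the batch mixing behind these mixup numbers, as implemented by `get_data` in the script below, reduces to the following standalone NumPy sketch. The batch size, class count, alpha, and the random test data here are illustrative only.

```python
import numpy as np

BATCH_SIZE, CLASS_NUM, ALPHA = 128, 10, 1.0

def mixup_batch(images, one_hot_labels, alpha=ALPHA):
    # One mixing weight per example, drawn from Beta(alpha, alpha).
    weight = np.random.beta(alpha, alpha, BATCH_SIZE)
    x_weight = weight.reshape(BATCH_SIZE, 1, 1, 1)
    y_weight = weight.reshape(BATCH_SIZE, 1)
    # Pair every example with a randomly permuted partner from the same batch
    # (the old code instead split a double-sized batch in half).
    index = np.random.permutation(BATCH_SIZE)
    x = images * x_weight + images[index] * (1 - x_weight)
    y = one_hot_labels * y_weight + one_hot_labels[index] * (1 - y_weight)
    return x, y

# Illustrative usage on random CIFAR10-shaped data.
images = np.random.rand(BATCH_SIZE, 32, 32, 3).astype('float32')
labels = np.eye(CLASS_NUM)[np.random.randint(CLASS_NUM, size=BATCH_SIZE)]
mixed_x, mixed_y = mixup_batch(images, labels)
```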
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
 # File: cifar10-preact18-mixup.py
-# Author: Tao Hu <taohu620@gmail.com>
+# Author: Tao Hu <taohu620@gmail.com>, Yauheni Selivonchyk <y.selivonchyk@gmail.com>

 import numpy as np
 import argparse
 import os
+import tensorflow as tf

 from tensorpack import *
-from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.utils.gpu import get_nr_gpu
 from tensorpack.dataflow import dataset
-import tensorflow as tf
-from tensorflow.contrib.layers import variance_scaling_initializer

-"""
-This implementation uses the architecture of PreAct in:
-https://github.com/kuangliu/pytorch-cifar
-This is different from the one in cifar10-resnet.py
-
-Results:
-Validation error with the original 100-150-200 schedule:
-no mixup - 5.0%; mixup(alpha=1) - 3.8%
-Using 2x learning schedule, it can further improve to 4.7% and 3.2%.
-
-Usage:
-./cifar10-preact18-mixup.py          # train without mixup
-./cifar10-preact18-mixup.py --mixup  # with mixup
-"""
-
 BATCH_SIZE = 128
 CLASS_NUM = 10
+
+LR_SCHEDULE = [(0, 0.1), (100, 0.01), (150, 0.001)]
+WEIGHT_DECAY = 1e-4
+
+FILTER_SIZES = [64, 128, 256, 512]
+MODULE_SIZES = [2, 2, 2, 2]
+
+
+def preactivation_block(input, num_filters, stride=1):
+    num_filters_in = input.get_shape().as_list()[1]
+
+    # residual
+    net = BNReLU(input)
+    residual = Conv2D('conv1', net, num_filters, kernel_shape=3, stride=stride, use_bias=False, nl=BNReLU)
+    residual = Conv2D('conv2', residual, num_filters, kernel_shape=3, stride=1, use_bias=False, nl=tf.identity)
+
+    # identity
+    shortcut = input
+    if stride != 1 or num_filters_in != num_filters:
+        shortcut = Conv2D('shortcut', net, num_filters, kernel_shape=1, stride=stride, use_bias=False,
+                          nl=tf.identity)
+
+    return shortcut + residual


-class Model(ModelDesc):
+class ResNet_Cifar(ModelDesc):
     def _get_inputs(self):
         return [InputDesc(tf.float32, [None, 32, 32, 3], 'input'),
                 InputDesc(tf.float32, [None, CLASS_NUM], 'label')]

     def _build_graph(self, inputs):
-        image, label = inputs
-        image = image / 128.0
         assert tf.test.is_gpu_available()
-        image = tf.transpose(image, [0, 3, 1, 2])
+        image, label = inputs

-        def preactblock(input, name, in_planes, planes, stride=1):
-            with tf.variable_scope(name):
-                input2 = BNReLU(input)
-                if stride != 1 or in_planes != planes:
-                    shortcut = Conv2D('shortcut', input2, planes, kernel_shape=1, stride=stride, use_bias=False,
-                                      nl=tf.identity)
-                else:
-                    shortcut = input
-                input2 = Conv2D('conv1', input2, planes, kernel_shape=3, stride=1, use_bias=False, nl=BNReLU)
-                input2 = Conv2D('conv2', input2, planes, kernel_shape=3, stride=stride, use_bias=False, nl=BNReLU)
-                input2 = Conv2D('conv3', input2, planes, kernel_shape=3, stride=1, use_bias=False, nl=tf.identity)
-                input2 += shortcut
-            return input2
-
-        def _make_layer(input, planes, num_blocks, current_plane, stride, name):
-            strides = [stride] + [1] * (num_blocks - 1)  # first block stride = stride, the latter block stride = 1
-            for index, stride in enumerate(strides):
-                input = preactblock(input, "{}.{}".format(name, index), current_plane, planes, stride)
-                current_plane = planes
-            return input, current_plane
-
-        with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling], data_format='NCHW'), \
-                argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
-                         W_init=variance_scaling_initializer(mode='FAN_OUT')):
-            l = Conv2D('conv0', image, 64, kernel_shape=3, stride=1, use_bias=False)
-            current_plane = 64
-            l, current_plane = _make_layer(l, 64, 2, current_plane, stride=1, name="res1")
-            l, current_plane = _make_layer(l, 128, 2, current_plane, stride=2, name="res2")
-            l, current_plane = _make_layer(l, 256, 2, current_plane, stride=2, name="res3")
-            l, current_plane = _make_layer(l, 512, 2, current_plane, stride=2, name="res4")
-            l = GlobalAvgPooling('gap', l)
-
-        logits = FullyConnected('linear', l, out_dim=CLASS_NUM, nl=tf.identity)
+        MEAN_IMAGE = tf.constant([0.4914, 0.4822, 0.4465], dtype=tf.float32)
+        STD_IMAGE = tf.constant([0.2023, 0.1994, 0.2010], dtype=tf.float32)
+        image = ((image / 255.0) - MEAN_IMAGE) / STD_IMAGE
+        image = tf.transpose(image, [0, 3, 1, 2])

-        cost = tf.losses.softmax_cross_entropy(onehot_labels=label, logits=logits)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
+        pytorch_default_init = tf.variance_scaling_initializer(scale=1.0 / 3, mode='fan_in', distribution='uniform')
+        with argscope([Conv2D, BatchNorm, GlobalAvgPooling], data_format='NCHW'), \
+                argscope(Conv2D, W_init=pytorch_default_init):
+            net = Conv2D('conv0', image, 64, kernel_shape=3, stride=1, use_bias=False)
+            for i, blocks_in_module in enumerate(MODULE_SIZES):
+                for j in range(blocks_in_module):
+                    stride = 2 if j == 0 and i > 0 else 1
+                    with tf.variable_scope("res%d.%d" % (i, j)):
+                        net = preactivation_block(net, FILTER_SIZES[i], stride)
+            net = GlobalAvgPooling('gap', net)
+            logits = FullyConnected('linear', net, out_dim=CLASS_NUM,
+                                    nl=tf.identity, W_init=tf.random_normal_initializer(stddev=1e-3))
+
+        ce_cost = tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits)
+        ce_cost = tf.reduce_mean(ce_cost, name='cross_entropy_loss')

         single_label = tf.to_int32(tf.argmax(label, axis=1))
         wrong = tf.to_float(tf.logical_not(tf.nn.in_top_k(logits, single_label, 1)), name='wrong_vector')
         # monitor training error
-        add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
+        add_moving_summary(tf.reduce_mean(wrong, name='train_error'), ce_cost)
+        add_param_summary(('.*/W', ['histogram']))

-        # weight decay on all W of fc layers
-        wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
-                                          480000, 0.2, True)
-        wd_cost = tf.multiply(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
-        add_moving_summary(cost, wd_cost)
-
-        add_param_summary(('.*/W', ['histogram']))   # monitor W
-        self.cost = tf.add_n([cost, wd_cost], name='cost')
+        # weight decay on all W matrixes. including convolutional layers
+        wd_cost = tf.multiply(WEIGHT_DECAY, regularize_cost('.*', tf.nn.l2_loss), name='wd_cost')
+
+        self.cost = tf.add_n([ce_cost, wd_cost], name='cost')

     def _get_optimizer(self):
-        lr = tf.get_variable('learning_rate', initializer=0.01, trainable=False)
+        lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
         opt = tf.train.MomentumOptimizer(lr, 0.9)
         return opt
@@ -110,23 +89,14 @@ class Model(ModelDesc):
 def get_data(train_or_test, isMixup, alpha):
     isTrain = train_or_test == 'train'
     ds = dataset.Cifar10(train_or_test)
-    pp_mean = ds.get_per_pixel_mean()
     if isTrain:
         augmentors = [
             imgaug.CenterPaste((40, 40)),
             imgaug.RandomCrop((32, 32)),
             imgaug.Flip(horiz=True),
-            imgaug.MapImage(lambda x: x - pp_mean),
         ]
-    else:
-        augmentors = [
-            imgaug.MapImage(lambda x: x - pp_mean)
-        ]
-    ds = AugmentImageComponent(ds, augmentors)
-    if isMixup:
-        batch = 2 * BATCH_SIZE
-    else:
-        batch = BATCH_SIZE
+        ds = AugmentImageComponent(ds, augmentors)
+
+    batch = BATCH_SIZE
     ds = BatchData(ds, batch, remainder=not isTrain)
@@ -140,16 +110,15 @@ def get_data(train_or_test, isMixup, alpha):
         weight = np.random.beta(alpha, alpha, BATCH_SIZE)
         x_weight = weight.reshape(BATCH_SIZE, 1, 1, 1)
         y_weight = weight.reshape(BATCH_SIZE, 1)
-        x1, x2 = np.split(images, 2, axis=0)
+        index = np.random.permutation(BATCH_SIZE)
+        x1, x2 = images, images[index]
         x = x1 * x_weight + x2 * (1 - x_weight)
-        y1, y2 = np.split(one_hot_labels, 2, axis=0)
+        y1, y2 = one_hot_labels, one_hot_labels[index]
         y = y1 * y_weight + y2 * (1 - y_weight)
         return [x, y]

     ds = MapData(ds, f)
-    if isTrain:
-        ds = PrefetchData(ds, 3, 2)
     return ds
@@ -164,30 +133,25 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

-    logger.set_logger_dir(
-        os.path.join('train_log/cifar10-preact18-{}mixup'.format('' if args.mixup else 'no')))
+    log_foder = 'train_log/cifar10-preact18%s' % ('-mixup' if args.mixup else '')
+    logger.set_logger_dir(os.path.join(log_foder))

     dataset_train = get_data('train', args.mixup, args.alpha)
     dataset_test = get_data('test', args.mixup, args.alpha)
     steps_per_epoch = dataset_train.size()
-    # because mixup utilize two data to generate one data, so the learning rate schedule are doubled.
-    if args.mixup:
-        steps_per_epoch *= 2

     config = TrainConfig(
-        model=Model(),
-        dataflow=dataset_train,
+        model=ResNet_Cifar(),
+        data=QueueInput(dataset_train),
         callbacks=[
             ModelSaver(),
             InferenceRunner(dataset_test,
                             [ScalarStats('cost'), ClassificationError('wrong_vector')]),
-            ScheduledHyperParamSetter('learning_rate',
-                                      [(1, 0.1), (100, 0.01), (150, 0.001)])
+            ScheduledHyperParamSetter('learning_rate', LR_SCHEDULE)
         ],
         max_epoch=200,
         steps_per_epoch=steps_per_epoch,
         session_init=SaverRestore(args.load) if args.load else None
     )
-    nr_gpu = max(get_nr_gpu(), 1)
-    launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpu))
+    launch_train_with_config(config, SimpleTrainer())