Commit 6a143365 authored by Yuxin Wu's avatar Yuxin Wu

[ResNet] Prefix block_func with preresnet. Use group_func.

parent 466b5192
...@@ -18,7 +18,7 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -18,7 +18,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone, fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error, preresnet_group, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_imagenet_dataflow) get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
...@@ -55,7 +55,7 @@ class Model(ModelDesc): ...@@ -55,7 +55,7 @@ class Model(ModelDesc):
defs = RESNET_CONFIG[DEPTH] defs = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'): with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
logits = resnet_backbone(image, defs, bottleneck_se) logits = resnet_backbone(image, defs, preresnet_group, bottleneck_se)
loss = compute_loss_and_error(logits, label) loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss') wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
......
...@@ -20,7 +20,8 @@ from tensorpack.tfutils import argscope, get_model_loader ...@@ -20,7 +20,8 @@ from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_backbone, fbresnet_augmentor, preresnet_group,
preresnet_basicblock, preresnet_bottleneck, resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error, eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_imagenet_dataflow) get_imagenet_dataflow)
...@@ -29,10 +30,10 @@ INPUT_SHAPE = 224 ...@@ -29,10 +30,10 @@ INPUT_SHAPE = 224
DEPTH = None DEPTH = None
RESNET_CONFIG = { RESNET_CONFIG = {
18: ([2, 2, 2, 2], resnet_basicblock), 18: ([2, 2, 2, 2], preresnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock), 34: ([3, 4, 6, 3], preresnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck), 50: ([3, 4, 6, 3], preresnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck) 101: ([3, 4, 23, 3], preresnet_bottleneck)
} }
...@@ -58,7 +59,7 @@ class Model(ModelDesc): ...@@ -58,7 +59,7 @@ class Model(ModelDesc):
defs, block_func = RESNET_CONFIG[DEPTH] defs, block_func = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format): with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = resnet_backbone(image, defs, block_func) logits = resnet_backbone(image, defs, preresnet_group, block_func)
loss = compute_loss_and_error(logits, label) loss = compute_loss_and_error(logits, label)
......
...@@ -124,24 +124,34 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity): ...@@ -124,24 +124,34 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def apply_preactivation(l, preact): def apply_preactivation(l, preact):
""" """
'no_preact' for the first resblock in each group only, because the input is activated already. 'no_preact' for the first resblock in each group only, because the input is activated already.
'default' for all the non-first blocks, where identity mapping is preserved on shortcut path. 'bnrelu/relu' for all the non-first blocks, where identity mapping is preserved on shortcut path.
""" """
if preact == 'default': if preact == 'bnrelu':
shortcut = l shortcut = l
l = BNReLU('preact', l) l = BNReLU('preact', l)
elif preact == 'relu':
shortcut = l
l = tf.nn.relu(l)
else: else:
shortcut = l shortcut = l
return l, shortcut return l, shortcut
def resnet_basicblock(l, ch_out, stride, preact): def get_bn(zero_init=False):
if zero_init:
return lambda x, _: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
else:
return lambda x, _: BatchNorm('bn', x)
def preresnet_basicblock(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact) l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3) l = Conv2D('conv2', l, ch_out, 3)
return l + resnet_shortcut(shortcut, ch_out, stride) return l + resnet_shortcut(shortcut, ch_out, stride)
def resnet_bottleneck(l, ch_out, stride, preact): def preresnet_bottleneck(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact) l, shortcut = apply_preactivation(l, preact)
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU) l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
...@@ -156,22 +166,22 @@ def preresnet_group(l, name, block_func, features, count, stride): ...@@ -156,22 +166,22 @@ def preresnet_group(l, name, block_func, features, count, stride):
# first block doesn't need activation # first block doesn't need activation
l = block_func(l, features, l = block_func(l, features,
stride if i == 0 else 1, stride if i == 0 else 1,
'no_preact' if i == 0 else 'default') 'no_preact' if i == 0 else 'bnrelu')
# end of each group need an extra activation # end of each group need an extra activation
l = BNReLU('bnlast', l) l = BNReLU('bnlast', l)
return l return l
def resnet_backbone(image, num_blocks, block_func): def resnet_backbone(image, num_blocks, group_func, block_func):
with argscope(Conv2D, nl=tf.identity, use_bias=False, with argscope(Conv2D, nl=tf.identity, use_bias=False,
W_init=variance_scaling_initializer(mode='FAN_OUT')): W_init=variance_scaling_initializer(mode='FAN_OUT')):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(preresnet_group, 'group0', block_func, 64, num_blocks[0], 1) .apply(group_func, 'group0', block_func, 64, num_blocks[0], 1)
.apply(preresnet_group, 'group1', block_func, 128, num_blocks[1], 2) .apply(group_func, 'group1', block_func, 128, num_blocks[1], 2)
.apply(preresnet_group, 'group2', block_func, 256, num_blocks[2], 2) .apply(group_func, 'group2', block_func, 256, num_blocks[2], 2)
.apply(preresnet_group, 'group3', block_func, 512, num_blocks[3], 2) .apply(group_func, 'group3', block_func, 512, num_blocks[3], 2)
.GlobalAvgPooling('gap') .GlobalAvgPooling('gap')
.FullyConnected('linear', 1000, nl=tf.identity)()) .FullyConnected('linear', 1000, nl=tf.identity)())
return logits return logits
......
...@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz from tensorpack.utils import viz
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, preresnet_group, fbresnet_augmentor, preresnet_basicblock, preresnet_group,
image_preprocess, compute_loss_and_error) image_preprocess, compute_loss_and_error)
...@@ -40,8 +40,8 @@ class Model(ModelDesc): ...@@ -40,8 +40,8 @@ class Model(ModelDesc):
image = tf.transpose(image, [0, 3, 1, 2]) image = tf.transpose(image, [0, 3, 1, 2])
cfg = { cfg = {
18: ([2, 2, 2, 2], resnet_basicblock), 18: ([2, 2, 2, 2], preresnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock), 34: ([3, 4, 6, 3], preresnet_basicblock),
} }
defs, block_func = cfg[DEPTH] defs, block_func = cfg[DEPTH]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment