Commit de6a2fed authored by Yuxin Wu

More general resnet blocks.

parent 831e2402
@@ -13,6 +13,6 @@ Unless you want to send a simple several lines of PR that can be easily merged,
 1. We prefer to not have an example that is too similar to existing ones in terms of the tasks.
-2. Examples have to be able to reproduce (preferrably in some mesurable metrics) published or well-known experiments and results.
+2. Examples have to be able to reproduce (preferrably in some measurable metrics) published or well-known experiments and results.
 * Please run `flake8 .` under the root of this repo to lint your code, and make sure the command produces no output.
@@ -14,24 +14,15 @@ from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_
 from tensorpack.utils.gpu import get_num_gpu
 
 from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow, get_imagenet_tfdata
-from resnet_model import (
-    preresnet_basicblock, preresnet_bottleneck, preresnet_group,
-    resnet_backbone, resnet_group,
-    resnet_basicblock, resnet_bottleneck, resnext_32x4d_bottleneck, se_resnet_bottleneck)
+import resnet_model
+from resnet_model import preact_group, resnet_backbone, resnet_group
 
 
 class Model(ImageNetModel):
     def __init__(self, depth, mode='resnet'):
         if mode == 'se':
             assert depth >= 50
         self.mode = mode
-        basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
-        bottleneck = {
-            'resnet': resnet_bottleneck,
-            'resnext32x4d': resnext_32x4d_bottleneck,
-            'preact': preresnet_bottleneck,
-            'se': se_resnet_bottleneck}[mode]
+        basicblock = getattr(resnet_model, mode + '_basicblock', None)
+        bottleneck = getattr(resnet_model, mode + '_bottleneck', None)
         self.num_blocks, self.block_func = {
             18: ([2, 2, 2, 2], basicblock),
             34: ([3, 4, 6, 3], basicblock),
@@ -39,12 +30,14 @@ class Model(ImageNetModel):
             101: ([3, 4, 23, 3], bottleneck),
             152: ([3, 8, 36, 3], bottleneck)
         }[depth]
+        assert self.block_func is not None, \
+            "(mode={}, depth={}) not implemented!".format(mode, depth)
 
     def get_logits(self, image):
         with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
             return resnet_backbone(
                 image, self.num_blocks,
-                preresnet_group if self.mode == 'preact' else resnet_group, self.block_func)
+                preact_group if self.mode == 'preact' else resnet_group, self.block_func)
 
 
 def get_config(model):
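The point of the change above: the hand-maintained bottleneck dict becomes a naming convention, so any variant that resnet_model exports as `<mode>_basicblock` / `<mode>_bottleneck` is picked up automatically. A minimal sketch of the resulting lookup; the getattr convention and the assert come from the diff, while the ResNet-50 row is elided in the hunk and filled in here with the standard configuration (an assumption):

```python
# Sketch of the name-based block lookup introduced above.
import resnet_model

def resolve_blocks(mode, depth):
    # getattr(..., None) turns an unknown (mode, depth) pair into None,
    # which the assert below converts into a clear error message.
    basicblock = getattr(resnet_model, mode + '_basicblock', None)
    bottleneck = getattr(resnet_model, mode + '_bottleneck', None)
    num_blocks, block_func = {
        18: ([2, 2, 2, 2], basicblock),
        34: ([3, 4, 6, 3], basicblock),
        50: ([3, 4, 6, 3], bottleneck),   # assumed; this row is elided in the hunk
        101: ([3, 4, 23, 3], bottleneck),
        152: ([3, 8, 36, 3], bottleneck),
    }[depth]
    assert block_func is not None, \
        "(mode={}, depth={}) not implemented!".format(mode, depth)
    return num_blocks, block_func

# e.g. resolve_blocks('se', 50) -> ([3, 4, 6, 3], resnet_model.se_bottleneck)
```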
@@ -16,15 +16,6 @@ def resnet_shortcut(l, n_out, stride, activation=tf.identity):
         return l
 
 
-def apply_preactivation(l, preact):
-    if preact == 'bnrelu':
-        shortcut = l  # preserve identity mapping
-        l = BNReLU('preact', l)
-    else:
-        shortcut = l
-    return l, shortcut
-
-
 def get_bn(zero_init=False):
     """
     Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
@@ -35,14 +26,24 @@ def get_bn(zero_init=False):
         return lambda x, name=None: BatchNorm('bn', x)
 
 
-def preresnet_basicblock(l, ch_out, stride, preact):
+# ----------------- pre-activation resnet ----------------------
+def apply_preactivation(l, preact):
+    if preact == 'bnrelu':
+        shortcut = l  # preserve identity mapping
+        l = BNReLU('preact', l)
+    else:
+        shortcut = l
+    return l, shortcut
+
+
+def preact_basicblock(l, ch_out, stride, preact):
     l, shortcut = apply_preactivation(l, preact)
     l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out, 3)
     return l + resnet_shortcut(shortcut, ch_out, stride)
 
 
-def preresnet_bottleneck(l, ch_out, stride, preact):
+def preact_bottleneck(l, ch_out, stride, preact):
     # stride is applied on the second conv, following fb.resnet.torch
     l, shortcut = apply_preactivation(l, preact)
     l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
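apply_preactivation, now grouped with the rest of the pre-activation code, implements the ordering from "Identity Mappings in Deep Residual Networks" (https://arxiv.org/abs/1603.05027): BN and ReLU run before the convolutions, and the shortcut is captured before the pre-activation so the identity path stays untouched. A usage sketch, with illustrative variable names:

```python
# 'bnrelu': pre-activate the residual branch, keep the raw input for the shortcut.
l, shortcut = apply_preactivation(l, 'bnrelu')     # l == BNReLU(input), shortcut == input
# anything else (e.g. 'no_preact'): both paths see the unmodified input.
l, shortcut = apply_preactivation(l, 'no_preact')  # l == shortcut == input
```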
@@ -51,7 +52,7 @@ def preresnet_bottleneck(l, ch_out, stride, preact):
     return l + resnet_shortcut(shortcut, ch_out * 4, stride)
 
 
-def preresnet_group(name, l, block_func, features, count, stride):
+def preact_group(name, l, block_func, features, count, stride):
     with tf.variable_scope(name):
         for i in range(0, count):
             with tf.variable_scope('block{}'.format(i)):
@@ -62,6 +63,7 @@ def preresnet_group(name, l, block_func, features, count, stride):
         # end of each group need an extra activation
         l = BNReLU('bnlast', l)
     return l
+# ----------------- pre-activation resnet ----------------------
 
 
 def resnet_basicblock(l, ch_out, stride):
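The loop body of preact_group is elided between the two hunks above. A sketch of its likely shape, assuming (consistently with apply_preactivation) that the first block of each group subsamples and receives 'no_preact' while the rest receive 'bnrelu'; this is a reconstruction, not the verbatim elided code:

```python
import tensorflow as tf
from tensorpack.models import BNReLU

def preact_group(name, l, block_func, features, count, stride):
    with tf.variable_scope(name):
        for i in range(0, count):
            with tf.variable_scope('block{}'.format(i)):
                # First block subsamples and skips pre-activation, since the
                # previous group already ended with the extra 'bnlast' BNReLU.
                l = block_func(l, features,
                               stride if i == 0 else 1,
                               'no_preact' if i == 0 else 'bnrelu')
        # end of each group need an extra activation
        l = BNReLU('bnlast', l)
    return l
```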
@@ -84,7 +86,7 @@ def resnet_bottleneck(l, ch_out, stride, stride_first=False):
     return tf.nn.relu(out)
 
 
-def se_resnet_bottleneck(l, ch_out, stride):
+def se_bottleneck(l, ch_out, stride):
     shortcut = l
     l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
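The squeeze-and-excitation step itself falls in the elided middle of this hunk. For orientation, the SE pattern (Hu et al., https://arxiv.org/abs/1709.01507) is roughly the following; the layer names, the reduction ratio, and the NCHW layout are assumptions, not the verbatim elided code:

```python
import tensorflow as tf
from tensorpack.models import FullyConnected, GlobalAvgPooling

# Squeeze: pool the residual branch to one descriptor per channel.
# Excite: a two-layer bottleneck ending in a sigmoid gate, used to
# rescale the residual branch channel-wise before the shortcut add.
def se_rescale(l, ch_all, ratio=16):          # ch_all = ch_out * 4 in this block
    squeeze = GlobalAvgPooling('gap', l)      # N x ch_all
    squeeze = FullyConnected('fc1', squeeze, ch_all // ratio, activation=tf.nn.relu)
    squeeze = FullyConnected('fc2', squeeze, ch_all, activation=tf.nn.sigmoid)
    return l * tf.reshape(squeeze, [-1, ch_all, 1, 1])   # assumes NCHW layout
```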
@@ -102,7 +104,7 @@ def se_resnet_bottleneck(l, ch_out, stride):
     return tf.nn.relu(out)
 
 
-def resnext_32x4d_bottleneck(l, ch_out, stride):
+def resnext32x4d_bottleneck(l, ch_out, stride):
     shortcut = l
     l = Conv2D('conv1', l, ch_out * 2, 1, strides=1, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out * 2, 3, strides=stride, activation=BNReLU, split=32)
@@ -122,7 +124,7 @@ def resnet_group(name, l, block_func, features, count, stride):
 
 def resnet_backbone(image, num_blocks, group_func, block_func):
     with argscope(Conv2D, use_bias=False,
                   kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
-        # Note that this pads the image by [2, 3] instead of [3, 2].
+        # Note that TF pads the image by [2, 3] instead of [3, 2].
         # Similar things happen in later stride=2 layers as well.
         l = Conv2D('conv0', image, 64, 7, strides=2, activation=BNReLU)
         l = MaxPooling('pool0', l, pool_size=3, strides=2, padding='SAME')
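The reworded comment blames TF rather than this code because, with padding='SAME', TF computes the total padding needed and places the odd leftover pixel after the input, so the 7x7/stride-2 conv0 pads [2, 3] where a symmetric (fb.resnet.torch-style) implementation would pad [3, 2]. A quick check of the arithmetic, using a 224-pixel input as an illustrative size:

```python
# TF 'SAME' padding for conv0 (kernel 7, stride 2) on a 224-wide input.
size, k, s = 224, 7, 2
out = -(-size // s)                             # ceil(224 / 2) = 112
pad_total = max((out - 1) * s + k - size, 0)    # 5
pad_before = pad_total // 2                     # 2 (the extra pixel goes after)
pad_after = pad_total - pad_before              # 3
print(pad_before, pad_after)                    # -> 2 3
```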