Commit de6a2fed authored by Yuxin Wu

More general resnet blocks.

parent 831e2402
@@ -2,7 +2,7 @@ Thanks for your contribution!
 Unless you want to send a simple several lines of PR that can be easily merged, please note the following:
 * If you want to add a new feature,
   please open an issue first and indicate that you want to contribute.
   There are features that we prefer to not add to tensorpack, e.g. symbolic models
@@ -13,6 +13,6 @@ Unless you want to send a simple several lines of PR that can be easily merged,
 1. We prefer to not have an example that is too similar to existing ones in terms of the tasks.
-2. Examples have to be able to reproduce (preferrably in some mesurable metrics) published or well-known experiments and results.
+2. Examples have to be able to reproduce (preferrably in some measurable metrics) published or well-known experiments and results.
 * Please run `flake8 .` under the root of this repo to lint your code, and make sure the command produces no output.
@@ -14,24 +14,15 @@ from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_
 from tensorpack.utils.gpu import get_num_gpu
 from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow, get_imagenet_tfdata
-from resnet_model import (
-    preresnet_basicblock, preresnet_bottleneck, preresnet_group,
-    resnet_backbone, resnet_group,
-    resnet_basicblock, resnet_bottleneck, resnext_32x4d_bottleneck, se_resnet_bottleneck)
+import resnet_model
+from resnet_model import preact_group, resnet_backbone, resnet_group


 class Model(ImageNetModel):
     def __init__(self, depth, mode='resnet'):
-        if mode == 'se':
-            assert depth >= 50
         self.mode = mode
-        basicblock = preresnet_basicblock if mode == 'preact' else resnet_basicblock
-        bottleneck = {
-            'resnet': resnet_bottleneck,
-            'resnext32x4d': resnext_32x4d_bottleneck,
-            'preact': preresnet_bottleneck,
-            'se': se_resnet_bottleneck}[mode]
+        basicblock = getattr(resnet_model, mode + '_basicblock', None)
+        bottleneck = getattr(resnet_model, mode + '_bottleneck', None)
         self.num_blocks, self.block_func = {
             18: ([2, 2, 2, 2], basicblock),
             34: ([3, 4, 6, 3], basicblock),
@@ -39,12 +30,14 @@ class Model(ImageNetModel):
             101: ([3, 4, 23, 3], bottleneck),
             152: ([3, 8, 36, 3], bottleneck)
         }[depth]
+        assert self.block_func is not None, \
+            "(mode={}, depth={}) not implemented!".format(mode, depth)

     def get_logits(self, image):
         with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
             return resnet_backbone(
                 image, self.num_blocks,
                 preact_group if self.mode == 'preact' else resnet_group, self.block_func)


 def get_config(model):
...
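This getattr lookup is the "more general" part of the commit: a mode string `foo` now resolves to functions named `foo_basicblock` / `foo_bottleneck` in `resnet_model`, which is why the block functions are renamed in resnet_model.py below (`preresnet_*` becomes `preact_*`, `se_resnet_bottleneck` becomes `se_bottleneck`, `resnext_32x4d_bottleneck` becomes `resnext32x4d_bottleneck`) to match their mode strings. It also subsumes the deleted `if mode == 'se': assert depth >= 50` check: a variant that defines no basic block resolves to `None`, and the new assertion rejects that (mode, depth) pair. A minimal self-contained sketch of the pattern (the toy module and the `resolve` helper are illustrative, not repo code):

import types

# Stand-in for the resnet_model module ('blocks' is a toy name): the 'se'
# variant defines only a bottleneck, so shallow depths are unavailable.
blocks = types.ModuleType('blocks')

def resnet_basicblock(l, ch_out, stride): ...
def resnet_bottleneck(l, ch_out, stride): ...
def se_bottleneck(l, ch_out, stride): ...

blocks.resnet_basicblock = resnet_basicblock
blocks.resnet_bottleneck = resnet_bottleneck
blocks.se_bottleneck = se_bottleneck

def resolve(mode, depth):
    # Same convention as Model.__init__ above: depths 18/34 use
    # '<mode>_basicblock', deeper nets '<mode>_bottleneck'; getattr's
    # None default signals "this variant has no such block type".
    suffix = '_basicblock' if depth in (18, 34) else '_bottleneck'
    block_func = getattr(blocks, mode + suffix, None)
    assert block_func is not None, \
        "(mode={}, depth={}) not implemented!".format(mode, depth)
    return block_func

print(resolve('resnet', 18).__name__)   # resnet_basicblock
print(resolve('se', 50).__name__)       # se_bottleneck
try:
    resolve('se', 18)                   # no se_basicblock exists
except AssertionError as e:
    print(e)                            # (mode=se, depth=18) not implemented!

Adding a new variant thus requires no change to Model at all: defining a suitably named block function in resnet_model is enough.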
@@ -16,15 +16,6 @@ def resnet_shortcut(l, n_out, stride, activation=tf.identity):
     return l


-def apply_preactivation(l, preact):
-    if preact == 'bnrelu':
-        shortcut = l  # preserve identity mapping
-        l = BNReLU('preact', l)
-    else:
-        shortcut = l
-    return l, shortcut


 def get_bn(zero_init=False):
     """
     Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
@@ -35,14 +26,24 @@ def get_bn(zero_init=False):
         return lambda x, name=None: BatchNorm('bn', x)


-def preresnet_basicblock(l, ch_out, stride, preact):
+# ----------------- pre-activation resnet ----------------------
+def apply_preactivation(l, preact):
+    if preact == 'bnrelu':
+        shortcut = l  # preserve identity mapping
+        l = BNReLU('preact', l)
+    else:
+        shortcut = l
+    return l, shortcut
+
+
+def preact_basicblock(l, ch_out, stride, preact):
     l, shortcut = apply_preactivation(l, preact)
     l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out, 3)
     return l + resnet_shortcut(shortcut, ch_out, stride)


-def preresnet_bottleneck(l, ch_out, stride, preact):
+def preact_bottleneck(l, ch_out, stride, preact):
     # stride is applied on the second conv, following fb.resnet.torch
     l, shortcut = apply_preactivation(l, preact)
     l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
@@ -51,7 +52,7 @@ def preresnet_bottleneck(l, ch_out, stride, preact):
     return l + resnet_shortcut(shortcut, ch_out * 4, stride)
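The `fb.resnet.torch` comment marks a deliberate choice, the same one exposed as `stride_first=False` on `resnet_bottleneck` below: downsampling happens on the 3x3 conv rather than the first 1x1. A strided 1x1 conv reads every other pixel and discards the rest of its input outright, while a strided 3x3 still covers all input positions. Both placements yield the same output size, as this throwaway check shows (pure arithmetic; the helper name is illustrative):

def out_size(in_size, stride):
    # spatial size after a SAME-padded conv
    return (in_size + stride - 1) // stride

# stride_first=True:  1x1/s2 -> 3x3/s1 -> 1x1/s1  (original paper)
# stride_first=False: 1x1/s1 -> 3x3/s2 -> 1x1/s1  (fb.resnet.torch)
for stride_first in (True, False):
    hw = 56
    hw = out_size(hw, 2 if stride_first else 1)  # conv1, 1x1
    hw = out_size(hw, 1 if stride_first else 2)  # conv2, 3x3
    hw = out_size(hw, 1)                         # conv3, 1x1
    print(stride_first, hw)                      # both print 28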
-def preresnet_group(name, l, block_func, features, count, stride):
+def preact_group(name, l, block_func, features, count, stride):
     with tf.variable_scope(name):
         for i in range(0, count):
             with tf.variable_scope('block{}'.format(i)):
@@ -62,6 +63,7 @@ def preresnet_group(name, l, block_func, features, count, stride):
         # end of each group need an extra activation
         l = BNReLU('bnlast', l)
     return l
+# ----------------- pre-activation resnet ----------------------
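The loop body is elided in the hunk above, but the structure around it explains the design: pre-activation blocks return a bare sum with no trailing activation, so each group finishes with the extra `BNReLU('bnlast', ...)`, and the first block of a group can skip pre-activation because the previous group (or the stem conv, which ends in BNReLU) already activated its input. A sketch of such a loop, with `block_func` stubbed so the control flow runs standalone; the `'no_preact'` flag string is an assumption, not shown in this diff:

def block_func(l, features, stride, preact):
    # stub: record what the real block would be built with
    return l + ['block(features={}, stride={}, preact={})'.format(features, stride, preact)]

def preact_group_sketch(l, features, count, stride):
    for i in range(0, count):
        l = block_func(l, features,
                       stride if i == 0 else 1,               # downsample once, at entry
                       'no_preact' if i == 0 else 'bnrelu')   # entry input is pre-activated
    l = l + ['BNReLU(bnlast)']   # blocks return bare sums, so close the group
    return l

for line in preact_group_sketch([], 64, 3, 2):
    print(line)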
 def resnet_basicblock(l, ch_out, stride):
@@ -84,7 +86,7 @@ def resnet_bottleneck(l, ch_out, stride, stride_first=False):
     return tf.nn.relu(out)


-def se_resnet_bottleneck(l, ch_out, stride):
+def se_bottleneck(l, ch_out, stride):
     shortcut = l
     l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
@@ -102,7 +104,7 @@ def se_resnet_bottleneck(l, ch_out, stride):
     return tf.nn.relu(out)


-def resnext_32x4d_bottleneck(l, ch_out, stride):
+def resnext32x4d_bottleneck(l, ch_out, stride):
     shortcut = l
     l = Conv2D('conv1', l, ch_out * 2, 1, strides=1, activation=BNReLU)
     l = Conv2D('conv2', l, ch_out * 2, 3, strides=stride, activation=BNReLU, split=32)
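The renamed function still encodes its configuration in its name, and the name can be read straight off the `conv2` line: `split=32` is the cardinality (32 grouped-convolution paths), and at the first stage (`ch_out=64`) the `ch_out * 2 = 128` mid-channels divide into 32 groups of width 4, hence "32x4d". A quick arithmetic check:

# Group width per stage for the 32x4d layout above (pure arithmetic;
# the stage widths follow the standard resnet channel progression).
cardinality = 32                       # the split= argument to conv2
for ch_out in (64, 128, 256, 512):     # per-stage bottleneck widths
    mid_channels = ch_out * 2          # conv2 channel count
    print(ch_out, mid_channels // cardinality)   # -> 4, 8, 16, 32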
@@ -122,7 +124,7 @@ def resnet_group(name, l, block_func, features, count, stride):
 def resnet_backbone(image, num_blocks, group_func, block_func):
     with argscope(Conv2D, use_bias=False,
                   kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
-        # Note that this pads the image by [2, 3] instead of [3, 2].
+        # Note that TF pads the image by [2, 3] instead of [3, 2].
         # Similar things happen in later stride=2 layers as well.
         l = Conv2D('conv0', image, 64, 7, strides=2, activation=BNReLU)
         l = MaxPooling('pool0', l, pool_size=3, strides=2, padding='SAME')
...
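The clarified comment refers to TensorFlow's 'SAME' padding rule: pad just enough that the output size is ceil(input / stride), and when the total padding is odd, put the extra pixel at the end (right/bottom). For `conv0` (kernel 7, stride 2) on a 224-pixel side that means 5 pixels of padding split as [2, 3], where symmetric padding in the fb.resnet.torch style would use 3 on each side. A quick check of the arithmetic (pure Python; the helper name is illustrative):

import math

def same_padding(in_size, kernel, stride):
    # TF 'SAME' rule: out = ceil(in / stride); total padding makes that
    # possible, and the odd leftover pixel goes at the end.
    out = math.ceil(in_size / stride)
    total = max((out - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2        # (pad_begin, pad_end)

print(same_padding(224, 7, 2))   # (2, 3) -> the [2, 3] in the comment
print(same_padding(112, 3, 2))   # (0, 1) -> pool0 and later stride=2 layers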