Commit 8591e253 authored by Yuxin Wu

[ResNet] add non-preact resnet

parent 6a143365
@@ -139,9 +139,9 @@ def apply_preactivation(l, preact):
 def get_bn(zero_init=False):
     if zero_init:
-        return lambda x, _: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
+        return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
     else:
-        return lambda x, _: BatchNorm('bn', x)
+        return lambda x, name: BatchNorm('bn', x)


 def preresnet_basicblock(l, ch_out, stride, preact):
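
The callable returned by get_bn is passed as the `nl` argument of tensorpack's Conv2D, which invokes it with the layer output and a name; renaming the ignored `_` parameter to `name` makes that two-argument contract explicit even though the name is still unused. The zero_init=True variant initializes the BatchNorm gamma to zero, so the residual branch it terminates starts out emitting (approximately) zeros. A minimal sketch of that effect, assuming plain TF 1.x ops instead of tensorpack's BatchNorm wrapper and made-up shapes:

    import tensorflow as tf

    x = tf.random_normal([8, 16, 16, 64])   # hypothetical input feature map
    # gamma starts at 0, so BN(x) = gamma * x_hat + beta is ~0 at initialization
    branch = tf.layers.batch_normalization(
        x, gamma_initializer=tf.zeros_initializer())
    out = tf.nn.relu(x + branch)             # initially ~ relu(x): the block
                                             # reduces to its shortcut at the start
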
@@ -172,6 +172,27 @@ def preresnet_group(l, name, block_func, features, count, stride):
     return l


+def resnet_bottleneck(l, ch_out, stride, preact):
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
+    return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
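
The bottleneck squeezes to ch_out channels with a 1x1 convolution, applies the (possibly strided) 3x3, then expands back to 4 * ch_out; only the final BatchNorm gets the zero-initialized gamma, while the shortcut's BatchNorm keeps a normal one. A worked shape example (sizes are illustrative, not from the commit):

    # input l: [N, 56, 56, 256], ch_out = 64, stride = 2
    # conv1 (1x1, BN+ReLU)           -> [N, 56, 56, 64]
    # conv2 (3x3, stride 2, BN+ReLU) -> [N, 28, 28, 64]
    # conv3 (1x1, BN, zero gamma)    -> [N, 28, 28, 256]
    # resnet_shortcut projects the input to [N, 28, 28, 256] so the addition is valid
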
+def resnet_group(l, name, block_func, features, count, stride):
+    with tf.variable_scope(name):
+        for i in range(0, count):
+            with tf.variable_scope('block{}'.format(i)):
+                # first block doesn't need activation
+                l = block_func(l, features,
+                               stride if i == 0 else 1,
+                               'no_preact' if i == 0 else 'relu')
+        # end of each group needs an extra activation
+        l = tf.nn.relu(l)
+    return l
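
Unrolled, the loop schedules the stride and preactivation arguments like this for a group with count = 3 and stride = 2 (an illustrative expansion, not code from the diff):

    l = block_func(l, features, 2, 'no_preact')  # block0: strided, no preactivation
    l = block_func(l, features, 1, 'relu')       # block1
    l = block_func(l, features, 1, 'relu')       # block2
    l = tf.nn.relu(l)                            # extra activation closing the group
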
+def resnet_backbone(image, num_blocks, group_func, block_func):
+    with argscope(Conv2D, nl=tf.identity, use_bias=False,
+                  W_init=variance_scaling_initializer(mode='FAN_OUT')):
......
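
The diff is truncated here. Going by the signatures above, the backbone would be assembled from the group and block functions; a hypothetical ResNet-50-style call (the block counts are assumed, not shown in this commit):

    logits = resnet_backbone(image, [3, 4, 6, 3], resnet_group, resnet_bottleneck)
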