Commit b5b2d4a0 authored by Yuxin Wu

cleanup the use of preact

parent 760a8112
......@@ -23,7 +23,6 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def apply_preactivation(l, preact):
if preact == 'bnrelu':
# this is used only for preact-resnet
shortcut = l # preserve identity mapping
l = BNReLU('preact', l)
else:
......@@ -70,26 +69,26 @@ def preresnet_group(l, name, block_func, features, count, stride):
return l
def resnet_basicblock(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
def resnet_basicblock(l, ch_out, stride):
shortcut = l
l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, nl=get_bn(zero_init=True))
return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False))
def resnet_bottleneck(l, ch_out, stride, preact, stride_first=False):
def resnet_bottleneck(l, ch_out, stride, stride_first=False):
"""
stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv.
"""
l, shortcut = apply_preactivation(l, preact)
shortcut = l
l = Conv2D('conv1', l, ch_out, 1, stride=stride if stride_first else 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=1 if stride_first else stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def se_resnet_bottleneck(l, ch_out, stride, preact):
l, shortcut = apply_preactivation(l, preact)
def se_resnet_bottleneck(l, ch_out, stride):
shortcut = l
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
......@@ -105,8 +104,7 @@ def resnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name):
for i in range(0, count):
with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features,
stride if i == 0 else 1, 'no_preact')
l = block_func(l, features, stride if i == 0 else 1)
# end of each block need an activation
l = tf.nn.relu(l)
return l
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment