Commit b5b2d4a0 authored by Yuxin Wu's avatar Yuxin Wu

cleanup the use of preact

parent 760a8112
...@@ -23,7 +23,6 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity): ...@@ -23,7 +23,6 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def apply_preactivation(l, preact): def apply_preactivation(l, preact):
if preact == 'bnrelu': if preact == 'bnrelu':
# this is used only for preact-resnet
shortcut = l # preserve identity mapping shortcut = l # preserve identity mapping
l = BNReLU('preact', l) l = BNReLU('preact', l)
else: else:
...@@ -70,26 +69,26 @@ def preresnet_group(l, name, block_func, features, count, stride): ...@@ -70,26 +69,26 @@ def preresnet_group(l, name, block_func, features, count, stride):
return l return l
def resnet_basicblock(l, ch_out, stride, preact): def resnet_basicblock(l, ch_out, stride):
l, shortcut = apply_preactivation(l, preact) shortcut = l
l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, nl=get_bn(zero_init=True)) l = Conv2D('conv2', l, ch_out, 3, nl=get_bn(zero_init=True))
return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False)) return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False))
def resnet_bottleneck(l, ch_out, stride, preact, stride_first=False): def resnet_bottleneck(l, ch_out, stride, stride_first=False):
""" """
stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv. stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv.
""" """
l, shortcut = apply_preactivation(l, preact) shortcut = l
l = Conv2D('conv1', l, ch_out, 1, stride=stride if stride_first else 1, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 1, stride=stride if stride_first else 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=1 if stride_first else stride, nl=BNReLU) l = Conv2D('conv2', l, ch_out, 3, stride=1 if stride_first else stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True)) l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False)) return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def se_resnet_bottleneck(l, ch_out, stride, preact): def se_resnet_bottleneck(l, ch_out, stride):
l, shortcut = apply_preactivation(l, preact) shortcut = l
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU) l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU) l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True)) l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
...@@ -105,8 +104,7 @@ def resnet_group(l, name, block_func, features, count, stride): ...@@ -105,8 +104,7 @@ def resnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name): with tf.variable_scope(name):
for i in range(0, count): for i in range(0, count):
with tf.variable_scope('block{}'.format(i)): with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features, l = block_func(l, features, stride if i == 0 else 1)
stride if i == 0 else 1, 'no_preact')
# end of each block need an activation # end of each block need an activation
l = tf.nn.relu(l) l = tf.nn.relu(l)
return l return l
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment