Commit 5f56f6a5 authored by Mikhail Sirotenko's avatar Mikhail Sirotenko Committed by Yuxin Wu

Fix se_resnet_bottleneck to support NHWC data format (#514)

* Fix se_resnet_bottleneck to support NHWC data format

* Fixed my fix: removed original line

* Remove trailing whitespace
parent 53a41ae8
...@@ -96,7 +96,11 @@ def se_resnet_bottleneck(l, ch_out, stride): ...@@ -96,7 +96,11 @@ def se_resnet_bottleneck(l, ch_out, stride):
squeeze = GlobalAvgPooling('gap', l) squeeze = GlobalAvgPooling('gap', l)
squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu) squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid) squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1]) data_format = get_arg_scope()['Conv2D']['data_format']
ch_ax = 1 if data_format == 'NCHW' else 3
shape = [-1, 1, 1, 1]
shape[ch_ax] = ch_out * 4
l = l * tf.reshape(squeeze, shape)
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False)) return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment