Commit 0d20032a authored by Yuxin Wu

simplify resnet BNReLU

parent 5828a161
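
The change is mechanical throughout: every explicit `BatchNorm` followed by `tf.nn.relu` becomes a single `BNReLU`, used either as a standalone layer (as in `BNReLU('bnlast', l)`) or passed as the `nl` (nonlinearity) argument of `Conv2D`. A minimal sketch of what such a helper computes, in plain TF 1.x; the names and batch-norm details here are assumptions, and tensorpack's real `BNReLU` is a registered layer whose `BatchNorm` also tracks the moving statistics needed at inference:

```python
import tensorflow as tf

def BNReLU(x, name=None):
    # Batch-normalize, then ReLU: the two-op pattern this commit fuses.
    x = tf.layers.batch_normalization(x, name=(name or 'bn'))
    return tf.nn.relu(x, name='output')
```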
README.md:

```diff
@@ -44,7 +44,4 @@ The validation error here is computed on test set.
 
 ![cifar10](cifar10-resnet.png)
 
-Download model:
-[Cifar10 ResNet-110 (n=18)](https://drive.google.com/open?id=0B9IPQTvr2BBkTXBlZmh1cmlnQ0k)
-
 Also see an implementation of [DenseNet](https://github.com/YixuanLi/densenet-tensorflow) from [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993).
```
cifar10-resnet.py:

```diff
@@ -54,16 +54,9 @@ class Model(ModelDesc):
             stride1 = 1
             with tf.variable_scope(name) as scope:
-                if not first:
-                    b1 = BatchNorm('bn1', l)
-                    b1 = tf.nn.relu(b1)
-                else:
-                    b1 = l
-                c1 = Conv2D('conv1', b1, out_channel, stride=stride1)
-                b2 = BatchNorm('bn2', c1)
-                b2 = tf.nn.relu(b2)
-                c2 = Conv2D('conv2', b2, out_channel)
+                b1 = l if first else BNReLU(l)
+                c1 = Conv2D('conv1', b1, out_channel, stride=stride1, nl=BNReLU)
+                c2 = Conv2D('conv2', c1, out_channel)
                 if increase_dim:
                     l = AvgPooling('pool', l, 2)
                     l = tf.pad(l, [[0,0], [0,0], [0,0], [in_channel//2, in_channel//2]])
```
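
The nine-line block body collapses to three because `Conv2D`'s `nl` argument applies a nonlinearity to the convolution output (`tf.identity` by default, as the `argscope` in the next hunk shows). A self-contained sketch of that contract with simplified stand-ins for the tensorpack layers; all names and shapes here are illustrative:

```python
import tensorflow as tf

def BNReLU(x, name=None):
    x = tf.layers.batch_normalization(x, name=(name or 'bn'))
    return tf.nn.relu(x)

def Conv2D(name, x, ch, kernel=3, stride=1, nl=tf.identity):
    # `nl` is applied to the conv output, so nl=BNReLU is the same as
    # wrapping the returned tensor in BNReLU(...) by hand.
    with tf.variable_scope(name):
        x = tf.layers.conv2d(x, ch, kernel, strides=stride,
                             padding='same', use_bias=False)
        return nl(x)

x = tf.placeholder(tf.float32, [None, 32, 32, 16])
y1 = BNReLU(Conv2D('conv1a', x, 16))       # old style: two explicit steps
y2 = Conv2D('conv1b', x, 16, nl=BNReLU)    # new style: one call
```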
```diff
@@ -73,9 +66,7 @@ class Model(ModelDesc):
         with argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
                       W_init=variance_scaling_initializer(mode='FAN_OUT')):
-            l = Conv2D('conv0', image, 16)
-            l = BatchNorm('bn0', l)
-            l = tf.nn.relu(l)
+            l = Conv2D('conv0', image, 16, nl=BNReLU)
             l = residual('res1.0', l, first=True)
             for k in range(1, self.n):
                 l = residual('res1.{}'.format(k), l)
```
```diff
@@ -89,8 +80,7 @@ class Model(ModelDesc):
             l = residual('res3.0', l, increase_dim=True)
             for k in range(1, self.n):
                 l = residual('res3.' + str(k), l)
-            l = BatchNorm('bnlast', l)
-            l = tf.nn.relu(l)
+            l = BNReLU('bnlast', l)
             # 8,c=64
             l = GlobalAvgPooling('gap', l)
```
imagenet-resnet.py → imagenet-resnet-short.py:

```diff
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: imagenet-resnet.py
+# File: imagenet-resnet-short.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import cv2
```
```diff
@@ -42,36 +42,30 @@ class Model(ModelDesc):
         def basicblock(l, ch_out, stride, preact):
             ch_in = l.get_shape().as_list()[-1]
-            input = l
             if preact == 'both_preact':
-                l = BatchNorm('preact', l)
-                l = tf.nn.relu(l, name='preact-relu')
+                l = BNReLU('preact', l)
                 input = l
             elif preact != 'no_preact':
-                l = BatchNorm('preact', l)
-                l = tf.nn.relu(l, name='preact-relu')
-            l = Conv2D('conv1', l, ch_out, 3, stride=stride)
-            l = BatchNorm('bn', l)
-            l = tf.nn.relu(l)
+                input = l
+                l = BNReLU('preact', l)
+            else:
+                input = l
+            l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
             l = Conv2D('conv2', l, ch_out, 3)
             return l + shortcut(input, ch_in, ch_out, stride)
 
         def bottleneck(l, ch_out, stride, preact):
             ch_in = l.get_shape().as_list()[-1]
-            input = l
             if preact == 'both_preact':
-                l = BatchNorm('preact', l)
-                l = tf.nn.relu(l, name='preact-relu')
+                l = BNReLU('preact', l)
                 input = l
             elif preact != 'no_preact':
-                l = BatchNorm('preact', l)
-                l = tf.nn.relu(l, name='preact-relu')
-            l = Conv2D('conv1', l, ch_out, 1)
-            l = BatchNorm('bn1', l)
-            l = tf.nn.relu(l)
-            l = Conv2D('conv2', l, ch_out, 3, stride=stride)
-            l = BatchNorm('bn2', l)
-            l = tf.nn.relu(l)
+                input = l
+                l = BNReLU('preact', l)
+            else:
+                input = l
+            l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
+            l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
             l = Conv2D('conv3', l, ch_out * 4, 1)
             return l + shortcut(input, ch_in, ch_out * 4, stride)
```
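
Both block functions follow the pre-activation design of He et al.'s "Identity Mappings in Deep Residual Networks": `'both_preact'` applies `BNReLU` before the input is forked, so the shortcut and the residual branch share the activation; the default mode pre-activates only the residual branch (note that `input` is now captured before the `BNReLU`); `'no_preact'` skips pre-activation, which fits a block directly following a layer that already ends in `BNReLU`. Below is a hypothetical group builder showing how the modes are typically assigned; this function is an illustration consistent with the `.apply(layer, 'group1', block_func, ...)` calls in the next hunk, not code from the diff:

```python
import tensorflow as tf

def layer(l, name, block_func, ch_out, count, stride, first=False):
    # Hypothetical sketch: the first block of a group downsamples and,
    # unless it directly follows a BNReLU, pre-activates both branches;
    # the remaining blocks pre-activate the residual branch only.
    with tf.variable_scope(name):
        with tf.variable_scope('block0'):
            l = block_func(l, ch_out, stride,
                           'no_preact' if first else 'both_preact')
        for i in range(1, count):
            with tf.variable_scope('block{}'.format(i)):
                l = block_func(l, ch_out, 1, 'default')
    return l
```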
```diff
@@ -102,8 +96,7 @@ class Model(ModelDesc):
                       .apply(layer, 'group1', block_func, 128, defs[1], 2)
                       .apply(layer, 'group2', block_func, 256, defs[2], 2)
                       .apply(layer, 'group3', block_func, 512, defs[3], 2)
-                      .BatchNorm('bnlast')
-                      .tf.nn.relu()
+                      .BNReLU('bnlast')
                       .GlobalAvgPooling('gap')
                       .FullyConnected('linear', 1000, nl=tf.identity)())
```
LinearWrap:

```diff
@@ -57,7 +57,7 @@ class LinearWrap(object):
                 return LinearWrap(ret)
         else:
             def f(*args, **kwargs):
-                if isinstance(args[0], six.string_types):
+                if len(args) and isinstance(args[0], six.string_types):
                     name, args = args[0], args[1:]
                     ret = layer(name, self._t, *args, **kwargs)
                 else:
```
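
Why the guard matters: `LinearWrap` turns attribute access into a deferred layer call, and the old code indexed `args[0]` unconditionally, so chaining a layer with no positional arguments (a keyword-only call, or a raw op like the `.tf.nn.relu()` removed above) raised `IndexError`. Here is a stripped-down, self-contained re-creation of the mechanism; the registry and toy layer are invented for illustration, and only the guard line mirrors the real code:

```python
import six

def _double(name_or_x, x=None):
    # Toy stand-in for a registered layer: like tensorpack layers, it can
    # take an optional scope name as its first positional argument.
    return (name_or_x if x is None else x) * 2

LAYERS = {'Double': _double}   # hypothetical layer registry

class MiniLinearWrap(object):
    """Stripped-down sketch of LinearWrap's chaining logic."""
    def __init__(self, t):
        self._t = t

    def __getattr__(self, layer_name):
        layer = LAYERS[layer_name]
        def f(*args, **kwargs):
            # The fix: only treat args[0] as a scope name when a positional
            # argument actually exists, instead of indexing unconditionally.
            if len(args) and isinstance(args[0], six.string_types):
                name, args = args[0], args[1:]
                ret = layer(name, self._t, *args, **kwargs)
            else:
                ret = layer(self._t, *args, **kwargs)
            return MiniLinearWrap(ret)
        return f

    def __call__(self):
        return self._t

print(MiniLinearWrap(10).Double('d').Double()())   # 10 -> 20 -> 40
```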