Commit 0d20032a authored by Yuxin Wu's avatar Yuxin Wu

simplify resnet BNReLU

parent 5828a161
......@@ -44,7 +44,4 @@ The validation error here is computed on test set.
![cifar10](cifar10-resnet.png)
Download model:
[Cifar10 ResNet-110 (n=18)](https://drive.google.com/open?id=0B9IPQTvr2BBkTXBlZmh1cmlnQ0k)
Also see an implementation of [DenseNet](https://github.com/YixuanLi/densenet-tensorflow) from [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993).
......@@ -54,16 +54,9 @@ class Model(ModelDesc):
stride1 = 1
with tf.variable_scope(name) as scope:
if not first:
b1 = BatchNorm('bn1', l)
b1 = tf.nn.relu(b1)
else:
b1 = l
c1 = Conv2D('conv1', b1, out_channel, stride=stride1)
b2 = BatchNorm('bn2', c1)
b2 = tf.nn.relu(b2)
c2 = Conv2D('conv2', b2, out_channel)
b1 = l if first else BNReLU(l)
c1 = Conv2D('conv1', b1, out_channel, stride=stride1, nl=BNReLU)
c2 = Conv2D('conv2', c1, out_channel)
if increase_dim:
l = AvgPooling('pool', l, 2)
l = tf.pad(l, [[0,0], [0,0], [0,0], [in_channel//2, in_channel//2]])
......@@ -73,9 +66,7 @@ class Model(ModelDesc):
with argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
W_init=variance_scaling_initializer(mode='FAN_OUT')):
l = Conv2D('conv0', image, 16)
l = BatchNorm('bn0', l)
l = tf.nn.relu(l)
l = Conv2D('conv0', image, 16, nl=BNReLU)
l = residual('res1.0', l, first=True)
for k in range(1, self.n):
l = residual('res1.{}'.format(k), l)
......@@ -89,8 +80,7 @@ class Model(ModelDesc):
l = residual('res3.0', l, increase_dim=True)
for k in range(1, self.n):
l = residual('res3.' + str(k), l)
l = BatchNorm('bnlast', l)
l = tf.nn.relu(l)
l = BNReLU('bnlast', l)
# 8,c=64
l = GlobalAvgPooling('gap', l)
......
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet.py
# File: imagenet-resnet-short.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import cv2
......@@ -42,36 +42,30 @@ class Model(ModelDesc):
def basicblock(l, ch_out, stride, preact):
ch_in = l.get_shape().as_list()[-1]
input = l
if preact == 'both_preact':
l = BatchNorm('preact', l)
l = tf.nn.relu(l, name='preact-relu')
l = BNReLU('preact', l)
input = l
elif preact != 'no_preact':
l = BatchNorm('preact', l)
l = tf.nn.relu(l, name='preact-relu')
l = Conv2D('conv1', l, ch_out, 3, stride=stride)
l = BatchNorm('bn', l)
l = tf.nn.relu(l)
input = l
l = BNReLU('preact', l)
else:
input = l
l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3)
return l + shortcut(input, ch_in, ch_out, stride)
def bottleneck(l, ch_out, stride, preact):
ch_in = l.get_shape().as_list()[-1]
input = l
if preact == 'both_preact':
l = BatchNorm('preact', l)
l = tf.nn.relu(l, name='preact-relu')
l = BNReLU('preact', l)
input = l
elif preact != 'no_preact':
l = BatchNorm('preact', l)
l = tf.nn.relu(l, name='preact-relu')
l = Conv2D('conv1', l, ch_out, 1)
l = BatchNorm('bn1', l)
l = tf.nn.relu(l)
l = Conv2D('conv2', l, ch_out, 3, stride=stride)
l = BatchNorm('bn2', l)
l = tf.nn.relu(l)
input = l
l = BNReLU('preact', l)
else:
input = l
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1)
return l + shortcut(input, ch_in, ch_out * 4, stride)
......@@ -102,8 +96,7 @@ class Model(ModelDesc):
.apply(layer, 'group1', block_func, 128, defs[1], 2)
.apply(layer, 'group2', block_func, 256, defs[2], 2)
.apply(layer, 'group3', block_func, 512, defs[3], 2)
.BatchNorm('bnlast')
.tf.nn.relu()
.BNReLU('bnlast')
.GlobalAvgPooling('gap')
.FullyConnected('linear', 1000, nl=tf.identity)())
......
......@@ -57,7 +57,7 @@ class LinearWrap(object):
return LinearWrap(ret)
else:
def f(*args, **kwargs):
if isinstance(args[0], six.string_types):
if len(args) and isinstance(args[0], six.string_types):
name, args = args[0], args[1:]
ret = layer(name, self._t, *args, **kwargs)
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment