Commit fec3a4a5 authored by Yuxin Wu's avatar Yuxin Wu

add BNReLU as a nonlin

parent 7d9582a1
......@@ -46,21 +46,15 @@ class Model(ModelDesc):
image = image / 4.0 # just to make range smaller
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
l = BatchNorm('bn1', l, is_training)
l = tf.nn.relu(l)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=tf.identity)
l = BatchNorm('bn2', l, is_training)
l = tf.nn.relu(l)
l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training))
l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=tf.identity)
l = BatchNorm('bn3', l, is_training)
l = tf.nn.relu(l)
l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training))
l = FullyConnected('fc0', l, 1024 + 512,
b_init=tf.constant_initializer(0.1))
l = tf.nn.dropout(l, keep_prob)
......
......@@ -7,8 +7,9 @@ import tensorflow as tf
from copy import copy
from ._common import *
from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU']
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
@layer_register()
def Maxout(x, num_unit):
......@@ -59,3 +60,15 @@ def LeakyReLU(x, alpha, name=None):
return x * 0.5
else:
return tf.mul(x, 0.5, name=name)
def BNReLU(is_training):
"""
:returns: a activation function that performs BN + ReLU (a too common combination)
"""
def f(x, name=None):
with tf.variable_scope('bn'):
x = BatchNorm.f(x, is_training)
x = tf.nn.relu(x, name=name)
return x
return f
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment