Commit a19a57b2 authored by Yuxin Wu

add note in bn and stn

parent 0ba0336c
......@@ -15,31 +15,37 @@ __all__ = ['BatchNorm']
@layer_register()
def BatchNorm(x, is_training):
    """
    Batch normalization layer as described in:
    Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    http://arxiv.org/abs/1502.03167

    Notes:
    * Whole-population mean/variance is estimated by a running average of the
      per-batch mean/variance, with decay rate 0.999.
    * Epsilon for variance is set to 1e-5, as in torch/nn:
      https://github.com/torch/nn/blob/master/BatchNormalization.lua

    x: BHWC tensor
    is_training: bool. When true, normalize with the statistics of the current
        batch and update the running averages; otherwise normalize with the
        running averages.

    Returns the normalized tensor.
    """
    EPS = 1e-5
    is_training = bool(is_training)
    shape = x.get_shape().as_list()
    assert len(shape) == 4
    n_out = shape[-1]   # channel

    beta = tf.get_variable('beta', [n_out])
    gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))

    # Per-channel statistics, reduced over the batch and both spatial axes.
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    ema_apply_op = ema.apply([batch_mean, batch_var])
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

    if is_training:
        # Tie the running-average update to the tensors actually consumed, so
        # the update runs whenever the layer output is evaluated.
        with tf.control_dependencies([ema_apply_op]):
            mean, var = tf.identity(batch_mean), tf.identity(batch_var)
    else:
        batch = tf.cast(tf.shape(x)[0], tf.float32)
        # Unbiased variance estimator. The denominator is clamped so that a
        # batch of one sample does not divide by zero, which would make the
        # variance infinite and collapse the output to `beta`.
        denom = tf.maximum(batch - 1.0, 1.0)
        mean, var = ema_mean, ema_var * batch / denom

    normed = tf.nn.batch_norm_with_global_normalization(
        x, mean, var, beta, gamma, EPS, True)
    return normed
......@@ -42,6 +42,9 @@ def sample(img, coords):
def ImageSample(inputs):
"""
Sample the template image, using the given coordinate, by bilinear interpolation.
It mimics the same behavior described in:
Spatial Transformer Networks, http://arxiv.org/abs/1506.02025
inputs: list of [template, mapping]
template: bxhxwxc
mapping: bxh2xw2x2 (y, x) real-value coordinates
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment