Commit 12d27154 authored by Yuxin Wu

use unbiased variance in training

parent 132dcccd
......@@ -30,8 +30,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
* Epsilon for variance is set to 1e-5, as is `torch/nn <https://github.com/torch/nn/blob/master/BatchNormalization.lua>`_.
:param input: a NHWC tensor or a NC vector
:param use_local_stat: bool. whether to use mean/var of this batch or the running average.
Usually set to True in training and False in testing
:param use_local_stat: bool. whether to use mean/var of this batch or the moving average. Set to True in training and False in testing
:param decay: decay rate. default to 0.9.
:param epsilon: default to 1e-5.
"""
......@@ -40,9 +39,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
assert len(shape) in [2, 4]
n_out = shape[-1] # channel
assert n_out is not None
beta = tf.get_variable('beta', [n_out])
gamma = tf.get_variable(
'gamma', [n_out],
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
if len(shape) == 2:
......@@ -51,7 +50,8 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
emaname = 'EMA'
if not batch_mean.name.startswith('towerp'):
in_train_tower = not batch_mean.name.startswith('towerp')
if in_train_tower:
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
......@@ -65,6 +65,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
G = tf.get_default_graph()
# find training statistics in training tower
try:
mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
......@@ -81,11 +82,11 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
if use_local_stat:
with tf.control_dependencies([ema_apply_op]):
batch = tf.cast(tf.shape(x)[0], tf.float32)
mul = tf.select(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
batch_var = batch_var * mul # use unbiased variance estimator in training
return tf.nn.batch_normalization(
x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else:
batch = tf.cast(tf.shape(x)[0], tf.float32)
# XXX TODO batch==1?
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, 'bn')
x, ema_mean, ema_var, beta, gamma, epsilon, 'bn')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment