Commit ebf1d570 authored by Yuxin Wu

fix BatchRenorm (fix #360)

parent 8b487b90
@@ -6,8 +6,10 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
+from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm
 from ..tfutils.tower import get_current_tower_context
+from ..tfutils.collection import backup_collection, restore_collection
 from ..utils import logger
 from .common import layer_register, VariableHolder
@@ -31,7 +33,7 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     moving_mean = tf.get_variable('mean/EMA', [n_out],
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.constant_initializer(), trainable=False)
+                                 initializer=tf.constant_initializer(1.0), trainable=False)
     return beta, gamma, moving_mean, moving_var
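
The one-argument change in this hunk initializes variance/EMA to 1.0 rather than the default 0: the old BatchRenorm code further down divides by sqrt(moving_var) via tf.rsqrt, so a zero-initialized variance EMA yields infinite r/d corrections that only saturate at the clipping bounds. A minimal NumPy sketch with made-up per-channel statistics:

# Minimal NumPy sketch (made-up statistics): the renorm correction
# r = sqrt(batch_var) / sqrt(moving_var) degenerates when variance/EMA starts at 0.
import numpy as np

batch_var = np.array([0.5, 2.0, 1.5])   # hypothetical per-channel batch variance
rmax = 3.0

for init in (0.0, 1.0):                 # old default init vs. the fixed init of variance/EMA
    moving_var = np.full_like(batch_var, init)
    with np.errstate(divide='ignore'):
        r_raw = np.sqrt(batch_var) / np.sqrt(moving_var)
    r = np.clip(r_raw, 1.0 / rmax, rmax)
    print('init=%.1f raw=%s clipped=%s' % (init, r_raw, r))
# init=0.0: raw r is inf, so after clipping every channel saturates at rmax
# regardless of the batch statistics; init=1.0: r is the true stddev ratio.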
@@ -179,8 +181,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     * ``beta``: the bias term.
     * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
-    * ``mean/EMA``: the moving average of mean.
-    * ``variance/EMA``: the moving average of variance.
+    * ``moving_mean, renorm_mean, renorm_mean_weight``: See TF documentation.
+    * ``moving_variance, renorm_stddev, renorm_stddev_weight``: See TF documentation.
     """
     shape = x.get_shape().as_list()
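
The renamed variables in the docstring are the ones tf.layers' BatchNormalization creates when renorm=True, which the rewritten code below delegates to. A small TF 1.x graph-mode sketch (hypothetical input shape; assumes a TF version that exposes tf.layers.BatchNormalization with renorm support) that lists those variable names:

# TF 1.x graph-mode sketch with a hypothetical NHWC input: build a
# renorm-enabled BatchNormalization layer and print the non-trainable
# variables it creates (moving_mean, moving_variance, renorm_mean, ...).
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8, 8, 16])
layer = tf.layers.BatchNormalization(axis=3, renorm=True)
layer.apply(x, training=True)
for v in layer.non_trainable_variables:
    print(v.name)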
@@ -188,59 +190,44 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     assert ndims in [2, 4]
     if ndims == 2:
         data_format = 'NHWC'  # error using NCHW? (see #190)
+        x = tf.reshape(x, [-1, 1, 1, shape[1]])
-    if data_format == 'NCHW':
-        n_out = shape[1]
-    else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(
-        n_out, use_scale, use_bias, tf.constant_initializer(1.0))
     ctx = get_current_tower_context()
-    use_local_stat = ctx.is_training
-    # for BatchRenorm, use_local_stat should always be is_training, unless a
-    # different usage comes out in the future.
-    if use_local_stat:
-        if ndims == 2:
-            x = tf.reshape(x, [-1, 1, 1, n_out])
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
-            x, gamma, beta, epsilon=epsilon, is_training=True, data_format=data_format)
-        inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
-        r = tf.stop_gradient(tf.clip_by_value(
-            tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
-        d = tf.stop_gradient(tf.clip_by_value(
-            (batch_mean - moving_mean) * inv_sigma,
-            -dmax, dmax))
-        r = reshape_for_bn(r, ndims, n_out, data_format)
-        d = reshape_for_bn(d, ndims, n_out, data_format)
-        xn = xn * r + d
-        if ndims == 2:
-            xn = tf.squeeze(xn, [1, 2])
+    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
+    layer = TF_BatchNorm(
+        axis=1 if data_format == 'NCHW' else 3,
+        momentum=decay, epsilon=epsilon,
+        center=use_bias, scale=use_scale,
+        renorm=True,
+        renorm_clipping={
+            'rmin': 1.0 / rmax,
+            'rmax': rmax,
+            'dmax': dmax},
+        renorm_momentum=0.99,
+        fused=False)
+    xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
-    else:
-        if ndims == 4 and data_format == 'NCHW':
-            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
-                              for _ in [gamma, beta, moving_mean, moving_var]]
-            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
-        else:
-            xn = tf.nn.batch_normalization(
-                x, moving_mean, moving_var, beta, gamma, epsilon)
-    # training also needs EMA, so we should maintain it as long as there are
-    # corresponding EMA variables.
     if ctx.has_own_variables:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        # only apply update in this case
+        for v in layer.non_trainable_variables:
+            add_model_variable(v)
     else:
-        ret = tf.identity(xn, name='output')
+        # don't need update if we are sharing variables from an old tower
+        restore_collection(coll_bk)
-    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    if ndims == 2:
+        xn = tf.squeeze(xn, [1, 2])
+    ret = tf.identity(xn, name='output')
+    # TODO not sure whether to add moving_mean/moving_var to VH now
+    vh = ret.variables = VariableHolder()
     if use_scale:
-        vh.gamma = gamma
+        vh.gamma = layer.gamma
     if use_bias:
-        vh.beta = beta
+        vh.beta = layer.beta
     return ret
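
Both the removed hand-written path and the replacement layer with renorm=True implement the correction from the Batch Renormalization paper: the batch-normalized output is rescaled by r and shifted by d, with r = clip(sigma_batch / sigma_moving, 1/rmax, rmax) and d = clip((mu_batch - mu_moving) / sigma_moving, -dmax, dmax), and gradients stopped through both. A NumPy sketch of that correction on made-up per-channel statistics:

# NumPy sketch of the batch renorm correction (illustrative values only);
# it mirrors the r/d computation in the removed code above.
import numpy as np

rmax, dmax, epsilon = 3.0, 5.0, 1e-5
batch_mean = np.array([0.1, -0.3])     # hypothetical per-channel batch statistics
batch_var = np.array([0.8, 1.6])
moving_mean = np.array([0.0, -0.2])    # hypothetical running (EMA) statistics
moving_var = np.array([1.0, 1.2])

inv_sigma = 1.0 / np.sqrt(moving_var)                        # tf.rsqrt(moving_var)
r = np.clip(np.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax)
d = np.clip((batch_mean - moving_mean) * inv_sigma, -dmax, dmax)

x = np.random.randn(4, 2)                                    # a fake (N, C) batch
xn = (x - batch_mean) / np.sqrt(batch_var + epsilon)         # ordinary batch normalization
xn = xn * r + d                                              # renorm correction, r/d treated as constants
print(r, d, xn.shape)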
@@ -22,6 +22,7 @@ def backup_collection(keys):
         dict: the backup
     """
     ret = {}
+    assert isinstance(keys, (list, tuple))
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret
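
For context, backup_collection / restore_collection are what the rewritten BatchRenorm uses to throw away the layer's UPDATE_OPS when the current tower only reuses variables from an existing tower. A minimal TF 1.x graph-mode sketch of that pattern; should_keep_updates is a hypothetical stand-in for ctx.has_own_variables:

# Minimal TF 1.x sketch of the backup/restore pattern around UPDATE_OPS
# (graph mode; should_keep_updates is a hypothetical stand-in for
# ctx.has_own_variables in the diff above).
import tensorflow as tf
from copy import copy

def backup_collection(keys):
    assert isinstance(keys, (list, tuple))
    return {k: copy(tf.get_collection(k)) for k in keys}

def restore_collection(backup):
    for k, v in backup.items():
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)

should_keep_updates = False                     # e.g. this tower shares variables
backup = backup_collection([tf.GraphKeys.UPDATE_OPS])

x = tf.placeholder(tf.float32, [None, 8, 8, 16])
layer = tf.layers.BatchNormalization(axis=3, renorm=True)
layer.apply(x, training=True)                   # appends EMA update ops to UPDATE_OPS

if not should_keep_updates:
    restore_collection(backup)                  # drop the update ops added above
print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))     # 0 when updates were dropped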