Commit ebf1d570 authored by Yuxin Wu

fix BatchRenorm (fix #360)

parent 8b487b90
@@ -6,8 +6,10 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
+from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm

 from ..tfutils.tower import get_current_tower_context
+from ..tfutils.collection import backup_collection, restore_collection
 from ..utils import logger
 from .common import layer_register, VariableHolder
@@ -31,7 +33,7 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     moving_mean = tf.get_variable('mean/EMA', [n_out],
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.constant_initializer(), trainable=False)
+                                 initializer=tf.constant_initializer(1.0), trainable=False)
     return beta, gamma, moving_mean, moving_var
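
One plausible reading of the initializer change above: the batch-renorm correction divides batch statistics by the moving statistics (see the code removed further down), so a moving variance that starts at 0 makes the correction factor blow up on the very first steps, while starting at 1.0 keeps it finite until the EMA has actually been updated. A tiny illustration with made-up numbers, not part of the commit:

```python
import numpy as np

# Hypothetical first-step statistics, only to illustrate the initializer change.
batch_var = np.array([0.25, 4.0])

moving_var_old = np.zeros(2)   # old default: tf.constant_initializer()
moving_var_new = np.ones(2)    # new default: tf.constant_initializer(1.0)

r_old = np.sqrt(batch_var) / np.sqrt(moving_var_old)   # -> [inf, inf]
r_new = np.sqrt(batch_var) / np.sqrt(moving_var_new)   # -> [0.5, 2.0]
print(r_old, r_new)
```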
@@ -179,8 +181,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         * ``beta``: the bias term.
         * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
-        * ``mean/EMA``: the moving average of mean.
-        * ``variance/EMA``: the moving average of variance.
+        * ``moving_mean, renorm_mean, renorm_mean_weight``: See TF documentation.
+        * ``moving_variance, renorm_stddev, renorm_stddev_weight``: See TF documentation.
     """
     shape = x.get_shape().as_list()
@@ -188,59 +190,44 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     assert ndims in [2, 4]
     if ndims == 2:
         data_format = 'NHWC'  # error using NCHW? (see #190)
+        x = tf.reshape(x, [-1, 1, 1, shape[1]])
     if data_format == 'NCHW':
         n_out = shape[1]
     else:
         n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(
-        n_out, use_scale, use_bias, tf.constant_initializer(1.0))

     ctx = get_current_tower_context()
-    use_local_stat = ctx.is_training
-    # for BatchRenorm, use_local_stat should always be is_training, unless a
-    # different usage comes out in the future.
-
-    if use_local_stat:
-        if ndims == 2:
-            x = tf.reshape(x, [-1, 1, 1, n_out])
-
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
-            x, gamma, beta, epsilon=epsilon, is_training=True, data_format=data_format)
-
-        inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
-        r = tf.stop_gradient(tf.clip_by_value(
-            tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
-        d = tf.stop_gradient(tf.clip_by_value(
-            (batch_mean - moving_mean) * inv_sigma,
-            -dmax, dmax))
-        r = reshape_for_bn(r, ndims, n_out, data_format)
-        d = reshape_for_bn(d, ndims, n_out, data_format)
-        xn = xn * r + d
-        if ndims == 2:
-            xn = tf.squeeze(xn, [1, 2])
-    else:
-        if ndims == 4 and data_format == 'NCHW':
-            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
-                              for _ in [gamma, beta, moving_mean, moving_var]]
-            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
-        else:
-            xn = tf.nn.batch_normalization(
-                x, moving_mean, moving_var, beta, gamma, epsilon)
-
-    # training also needs EMA, so we should maintain it as long as there are
-    # corresponding EMA variables.
+    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
+    layer = TF_BatchNorm(
+        axis=1 if data_format == 'NCHW' else 3,
+        momentum=decay, epsilon=epsilon,
+        center=use_bias, scale=use_scale,
+        renorm=True,
+        renorm_clipping={
+            'rmin': 1.0 / rmax,
+            'rmax': rmax,
+            'dmax': dmax},
+        renorm_momentum=0.99,
+        fused=False)
+    xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
     if ctx.has_own_variables:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        # only apply update in this case
+        for v in layer.non_trainable_variables:
+            add_model_variable(v)
     else:
-        ret = tf.identity(xn, name='output')
-
-    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+        # don't need update if we are sharing variables from an old tower
+        restore_collection(coll_bk)
+    if ndims == 2:
+        xn = tf.squeeze(xn, [1, 2])
+    ret = tf.identity(xn, name='output')
+    # TODO not sure whether to add moving_mean/moving_var to VH now
+    vh = ret.variables = VariableHolder()
     if use_scale:
-        vh.gamma = gamma
+        vh.gamma = layer.gamma
     if use_bias:
-        vh.beta = beta
+        vh.beta = layer.beta
     return ret
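
For context, the `renorm_clipping` dict passed to the TF layer above encodes the same correction that the deleted hand-written code computed with `tf.clip_by_value` and `tf.stop_gradient`. A minimal NumPy sketch of that correction, using the variable names from the removed code (`r`, `d`, `batch_mean`, `batch_var`, `moving_mean`, `moving_var`); the epsilon handling is an assumption for numerical safety, not taken from the diff:

```python
import numpy as np

def renorm_correction(batch_mean, batch_var, moving_mean, moving_var,
                      rmax, dmax, epsilon=1e-5):
    """Sketch of the r/d correction that the removed code built by hand."""
    sigma = np.sqrt(moving_var + epsilon)
    # r rescales the batch-normalized output toward the moving statistics;
    # clipped to [1/rmax, rmax] -- the 'rmin'/'rmax' entries of renorm_clipping.
    r = np.clip(np.sqrt(batch_var + epsilon) / sigma, 1.0 / rmax, rmax)
    # d shifts it; clipped to [-dmax, dmax] -- the 'dmax' entry.
    d = np.clip((batch_mean - moving_mean) / sigma, -dmax, dmax)
    # In the graph version both r and d were wrapped in tf.stop_gradient,
    # i.e. treated as constants; training then used xn * r + d before
    # applying gamma and beta.
    return r, d

# Example with made-up per-channel statistics:
r, d = renorm_correction(batch_mean=np.array([0.1, -0.2]),
                         batch_var=np.array([1.5, 0.8]),
                         moving_mean=np.zeros(2),
                         moving_var=np.ones(2),
                         rmax=3.0, dmax=5.0)
```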
@@ -22,6 +22,7 @@ def backup_collection(keys):
         dict: the backup
     """
     ret = {}
+    assert isinstance(keys, (list, tuple))
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret
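
The `backup_collection` / `restore_collection` pair used in the BatchRenorm hunk above is what lets a tower that merely reuses variables discard the duplicate moving-average update ops. A rough standalone sketch of that pattern, assuming `restore_collection` resets each collection to its snapshot (the function bodies and the `sharing_variables_from_old_tower` flag are illustrative, not the library implementation):

```python
from copy import copy
import tensorflow as tf

def backup_collection(keys):
    """Snapshot the listed graph collections (hence the new list/tuple assert)."""
    assert isinstance(keys, (list, tuple))
    return {k: copy(tf.get_collection(k)) for k in keys}

def restore_collection(backup):
    """Reset each collection to its snapshot, dropping anything added since."""
    for k, items in backup.items():
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(items)

# The pattern from the BatchRenorm change: snapshot UPDATE_OPS, build the layer
# (which registers its moving-average update ops), then restore the snapshot in
# towers that only reuse variables, so the updates are not run twice.
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
# ... build the batch-norm layer here ...
sharing_variables_from_old_tower = True   # illustrative flag, not from the diff
if sharing_variables_from_old_tower:
    restore_collection(coll_bk)
```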