Commit ab281e81 authored by Yuxin Wu

add batch renorm

parent 2b095948
@@ -11,13 +11,13 @@ from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from .common import layer_register
__all__ = ['BatchNorm']
__all__ = ['BatchNorm', 'BatchRenorm']
# decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
# Deprecated. Only kept for future reference.
# XXX This is deprecated. Only kept for future reference.
@layer_register(log_shape=False)
def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
shape = x.get_shape().as_list()
@@ -96,10 +96,51 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
def get_bn_variables(x, use_scale, use_bias):
    shape = x.get_shape().as_list()
    assert len(shape) in [2, 4]
    n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    if len(shape) == 2:
        x = tf.reshape(x, [-1, 1, 1, n_out])
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(), trainable=False)
    return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
    # TODO updating it later (similar to slim) might be faster?
    # TODO is there a way to use zero_debias in multi-GPU?
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False,
        name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False,
        name='var_ema_op')
    add_model_variable(moving_mean)
    add_model_variable(moving_var)
    with tf.control_dependencies([update_op1, update_op2]):
        return tf.identity(xn, name='output')
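As a side note, here is a minimal NumPy sketch (not part of this commit) of the arithmetic that each assign_moving_average call performs when zero_debias=False; the decay value and the batch statistics below are illustrative only:

import numpy as np

def ema_update(moving, batch, decay=0.9):
    # same update as assign_moving_average(..., zero_debias=False):
    # moving <- moving - (1 - decay) * (moving - batch) == decay * moving + (1 - decay) * batch
    return decay * moving + (1.0 - decay) * batch

moving_mean = np.zeros(3)
for _ in range(5):
    moving_mean = ema_update(moving_mean, np.array([1.0, 2.0, 3.0]))
# moving_mean is now roughly [0.41, 0.82, 1.23]; a decay much closer to 1 makes this
# start-up even slower, which is why the comment at the top of the file suggests 0.9.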
@layer_register(log_shape=False)
def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True):
"""
Batch normalization layer, as described in the paper:
Batch Normalization layer, as described in the paper:
`Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.
@@ -109,6 +150,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
Defaults to True in training and False in inference.
decay (float): decay rate of moving average.
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
Returns:
tf.Tensor: a tensor named ``output`` with the same shape as x.
@@ -121,53 +163,29 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
* ``variance/EMA``: the moving average of variance.
Note:
In multi-tower training, only the first training tower maintains a moving average.
In multi-GPU training, moving averages across GPUs are not aggregated.
This is consistent with most frameworks.
However, all GPUs use the moving averages maintained on the first GPU (instead of
their own), which is inconsistent with most frameworks (but consistent
with the official inceptionv3 example).
"""
shape = x.get_shape().as_list()
assert len(shape) in [2, 4]
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
if len(shape) == 2:
x = tf.reshape(x, [-1, 1, 1, n_out])
beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
# x * gamma + beta
beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
ctx = get_current_tower_context()
if use_local_stat is None:
use_local_stat = ctx.is_training
if use_local_stat != ctx.is_training:
# we allow the use of local_stat in testing (only print warnings)
# because it is useful for certain applications.
logger.warn("[BatchNorm] use_local_stat != is_training")
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True)
# maintain EMA only in the main training tower
if ctx.is_main_training_tower:
# TODO a way to use debias in multitower.
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
add_model_variable(moving_mean)
add_model_variable(moving_var)
else:
assert not ctx.is_training, "In training, local statistics have to be used!"
# TODO do I need to add_model_variable?
# consider some fixed-param tasks, such as loading a model and fine-tuning one layer
# fused seems slower in inference
# xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
# moving_mean, moving_var,
@@ -178,12 +196,65 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
if len(shape) == 2:
xn = tf.squeeze(xn, [1, 2])
# TODO for other towers, maybe we can make it depend on some op later
# TODO updating it later (similar to slim) might be faster?
# TODO the main tower already has too much work; would it be faster to update
# it only on the last tower?
# maintain EMA only on one GPU.
# TODO the first GPU already has too much work; it might be faster to update it on a different GPU
if ctx.is_main_training_tower:
with tf.control_dependencies([update_op1, update_op2]):
return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
return tf.identity(xn, name='output')
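For context, a hypothetical usage sketch (not part of this commit), assuming the layer_register convention that the first positional argument is the variable scope name; x, 'bn0' and 'bn1' are placeholder names:

l = BatchNorm('bn0', x)                                   # follows ctx.is_training by default
l = BatchNorm('bn0', x, use_local_stat=False)             # force mean/EMA and variance/EMA (only warns if inconsistent with is_training)
l = BatchNorm('bn1', x, use_scale=False, use_bias=False)  # normalization only, no learned affine transform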
@layer_register(log_shape=False)
def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
                use_scale=True, use_bias=True):
    """
    Batch Renormalization layer, as described in the paper:
    `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
    <https://arxiv.org/abs/1702.03275>`_.

    Args:
        x (tf.Tensor): an NHWC or NC tensor.
        rmax, dmax (tf.Tensor): scalar tensors, the maximum allowed corrections.
        decay (float): decay rate of moving average.
        epsilon (float): epsilon to avoid divide-by-zero.
        use_scale, use_bias (bool): whether to use the extra affine transformation or not.

    Returns:
        tf.Tensor: a tensor named ``output`` with the same shape as x.

    Variable Names:

    * ``beta``: the bias term.
    * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.
    """
    shape = x.get_shape().as_list()
    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)

    ctx = get_current_tower_context()
    use_local_stat = ctx.is_training
    # for BatchRenorm, use_local_stat should always be is_training, unless a
    # different usage comes up in the future.

    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
                                                           epsilon=epsilon, is_training=True)
        moving_sigma = tf.sqrt(moving_var, 'sigma')
        r = tf.stop_gradient(tf.clip_by_value(
            tf.sqrt(batch_var / moving_var), 1.0 / rmax, rmax))
        d = tf.stop_gradient(tf.clip_by_value(
            (batch_mean - moving_mean) / moving_sigma,
            -dmax, dmax))
        xn = xn * r + d
    else:
        xn = tf.nn.batch_normalization(
            x, moving_mean, moving_var, beta, gamma, epsilon)

    if len(shape) == 2:
        xn = tf.squeeze(xn, [1, 2])
    if ctx.is_main_training_tower:
        return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
    else:
        return tf.identity(xn, name='output')
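A hypothetical schedule sketch (not part of this commit): the Batch Renormalization paper starts with rmax=1, dmax=0 (equivalent to plain BatchNorm) and gradually relaxes the clipping bounds; the 20000-step ramp and the global_step tensor below are assumptions:

import tensorflow as tf

# global_step: an existing scalar int tensor counting training steps (assumed)
t = tf.cast(global_step, tf.float32) / 20000.0
rmax = tf.clip_by_value(1.0 + 2.0 * t, 1.0, 3.0)   # ramps from 1 to 3
dmax = tf.clip_by_value(5.0 * t, 0.0, 5.0)         # ramps from 0 to 5
l = BatchRenorm('bn_renorm', x, rmax, dmax)        # x: NHWC tensor (placeholder name)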
@@ -8,7 +8,7 @@ from contextlib import contextmanager
from .gradproc import apply_grad_processors as apply_gradproc
__all__ = ['apply_grad_processors', 'ProxyOptimizer',
'PostProcessVariablesOptimizer']
'PostProcessOptimizer', 'VariableAssignmentOptimizer']
class ProxyOptimizer(tf.train.Optimizer):
@@ -56,10 +56,10 @@ def apply_grad_processors(opt, gradprocs):
return _ApplyGradientProcessor(opt, gradprocs)
class PostProcessVariablesOptimizer(ProxyOptimizer):
class PostProcessOptimizer(ProxyOptimizer):
"""
An optimizer which applies an operation to variables
(e.g. clipping, quantization) after updating the gradient.
An optimizer which applies some "post-processing operation" per variable
(e.g. clipping, quantization) after the gradient update.
"""
def __init__(self, opt, func, colocate=True):
"""
@@ -69,12 +69,12 @@ class PostProcessVariablesOptimizer(ProxyOptimizer):
to perform for this variable after the gradient update.
colocate (boolean): colocate the function with the variable.
"""
super(PostProcessVariablesOptimizer, self).__init__(opt)
super(PostProcessOptimizer, self).__init__(opt)
self._func = func
self._colocate = colocate
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
update_op = super(PostProcessVariablesOptimizer, self).apply_gradients(
update_op = super(PostProcessOptimizer, self).apply_gradients(
grads_and_vars, global_step)
ops = []
with tf.control_dependencies([update_op]):
@@ -95,3 +95,23 @@ class PostProcessVariablesOptimizer(ProxyOptimizer):
yield
else:
yield
class VariableAssignmentOptimizer(PostProcessOptimizer):
    """
    An optimizer which assigns each variable a new value (e.g. clipping,
    quantization) after the gradient update.
    """
    def __init__(self, opt, func):
        """
        Args:
            opt (tf.train.Optimizer):
            func (tf.Variable -> tf.Tensor or None): the new value to be
                assigned to this variable after the gradient update.
        """
        def f(v):
            t = func(v)
            if t is None:
                return t
            return tf.assign(v, t, use_locking=False).op
        super(VariableAssignmentOptimizer, self).__init__(opt, f)
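A hypothetical usage sketch (not part of this commit): wrapping a standard optimizer so that selected weights are clipped back into [-1, 1] right after every gradient update; the 'W' name filter and the momentum settings are assumptions:

import tensorflow as tf

def clip_weights(v):
    if 'W' not in v.op.name:   # only touch variables whose name contains 'W' (assumption)
        return None            # returning None leaves the variable untouched
    return tf.clip_by_value(v, -1.0, 1.0)

opt = VariableAssignmentOptimizer(
    tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9), clip_weights)
# opt.apply_gradients(...) now also runs the clip assignments after the update op.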