Commit ab281e81 authored by Yuxin Wu

add batch renorm

parent 2b095948
@@ -11,13 +11,13 @@ from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from .common import layer_register
-__all__ = ['BatchNorm']
+__all__ = ['BatchNorm', 'BatchRenorm']
# decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
-# Deprecated. Only kept for future reference.
+# XXX This is deprecated. Only kept for future reference.
@layer_register(log_shape=False)
def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    shape = x.get_shape().as_list()
@@ -96,10 +96,51 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
def get_bn_variables(x, use_scale, use_bias):
shape = x.get_shape().as_list()
assert len(shape) in [2, 4]
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
if len(shape) == 2:
x = tf.reshape(x, [-1, 1, 1, n_out])
if use_bias:
beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
else:
beta = tf.zeros([n_out], name='beta')
if use_scale:
gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
else:
gamma = tf.ones([n_out], name='gamma')
# x * gamma + beta
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
# TODO update it later (similar to slim) might be faster?
# TODO is there a way to use zero_debias in multi-GPU?
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
add_model_variable(moving_mean)
add_model_variable(moving_var)
with tf.control_dependencies([update_op1, update_op2]):
return tf.identity(xn, name='output')
@layer_register(log_shape=False)
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
+def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
+              use_scale=True, use_bias=True):
    """
    Batch Normalization layer, as described in the paper:
    `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.
@@ -109,6 +150,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
            Defaults to True in training and False in inference.
        decay (float): decay rate of moving average.
        epsilon (float): epsilon to avoid divide-by-zero.
        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
    Returns:
        tf.Tensor: a tensor named ``output`` with the same shape as x.
@@ -121,53 +163,29 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    * ``variance/EMA``: the moving average of variance.
    Note:
-        In multi-tower training, only the first training tower maintains a moving average.
+        In multi-GPU training, moving averages across GPUs are not aggregated.
        This is consistent with most frameworks.
        However, all GPUs use the moving averages on the first GPU (instead of
        their own), which is inconsistent with most frameworks (but consistent
        with the official inceptionv3 example).
    """
    shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
+    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
if len(shape) == 2:
x = tf.reshape(x, [-1, 1, 1, n_out])
beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
# x * gamma + beta
    ctx = get_current_tower_context()
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    if use_local_stat != ctx.is_training:
# we allow the use of local_stat in testing (only print warnings)
# because it is useful to certain applications.
        logger.warn("[BatchNorm] use_local_stat != is_training")
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
            epsilon=epsilon, is_training=True)
# maintain EMA only in the main training tower
if ctx.is_main_training_tower:
# TODO a way to use debias in multitower.
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
add_model_variable(moving_mean)
add_model_variable(moving_var)
    else:
        assert not ctx.is_training, "In training, local statistics has to be used!"
# TODO do I need to add_model_variable.
# consider some fixed-param tasks, such as load model and fine tune one layer
        # fused seems slower in inference
        # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
        #     moving_mean, moving_var,
@@ -178,12 +196,65 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    if len(shape) == 2:
        xn = tf.squeeze(xn, [1, 2])
-    # TODO for other towers, maybe can make it depend some op later
-    # TODO update it later (similar to slim) might be faster?
-    # TODO main tower already has too many work, would it be faster to update
-    # it only on the last tower?
+    # maintain EMA only on one GPU.
+    # TODO the first GPU already has too much work; might be faster to update it on a different GPU
+    if ctx.is_main_training_tower:
+        return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
return tf.identity(xn, name='output')
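# --- Editorial sketch (not part of the commit) ---
# Hypothetical usage of the BatchNorm layer above, assuming tensorpack's
# layer_register calling convention (the first argument is the variable-scope
# name) inside a model's graph construction; the tensor `l` is made up:
#
#   l = BatchNorm('bn1', l)                      # NHWC feature map, scale + bias
#   l = BatchNorm('bn_fc', l, use_bias=False)    # NC input from a FC layer
#
# In training (use_local_stat defaults to ctx.is_training) the layer normalizes
# with batch statistics and the main tower updates ``mean/EMA``/``variance/EMA``;
# at inference it normalizes with those moving averages instead.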
@layer_register(log_shape=False)
def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True):
"""
Batch Renormalization layer, as described in the paper:
`Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
<https://arxiv.org/abs/1702.03275>`_.
Args:
x (tf.Tensor): a NHWC or NC tensor.
        rmax, dmax (tf.Tensor): scalar tensors, the maximum allowed corrections.
decay (float): decay rate of moving average.
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
Returns:
        tf.Tensor: a tensor named ``output`` with the same shape as x.
Variable Names:
* ``beta``: the bias term.
* ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
* ``mean/EMA``: the moving average of mean.
* ``variance/EMA``: the moving average of variance.
"""
shape = x.get_shape().as_list()
beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
ctx = get_current_tower_context()
use_local_stat = ctx.is_training
# for BatchRenorm, use_local_stat should always be is_training, unless a
# different usage comes out in the future.
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True)
moving_sigma = tf.sqrt(moving_var, 'sigma')
r = tf.stop_gradient(tf.clip_by_value(
tf.sqrt(batch_var / moving_var), 1.0 / rmax, rmax))
d = tf.stop_gradient(tf.clip_by_value(
(batch_mean - moving_mean) / moving_sigma,
-dmax, dmax))
xn = xn * r + d
else:
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
if len(shape) == 2:
xn = tf.squeeze(xn, [1, 2])
    if ctx.is_main_training_tower:
-        with tf.control_dependencies([update_op1, update_op2]):
-            return tf.identity(xn, name='output')
+        return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
    else:
        return tf.identity(xn, name='output')
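# --- Editorial sketch (not part of the commit) ---
# Hypothetical usage of BatchRenorm. rmax and dmax are scalar tensors; the paper
# suggests starting at rmax=1, dmax=0 (equivalent to plain BN) and gradually
# relaxing them, e.g. toward rmax=3 and dmax=5. One possible schedule, where
# `global_step_f` is assumed to be the training step as a float scalar tensor
# and the 20000-step ramp length is an arbitrary choice:
#
#   t = tf.minimum(global_step_f / 20000.0, 1.0)
#   rmax = 1.0 + 2.0 * t          # ramps 1 -> 3
#   dmax = 5.0 * t                # ramps 0 -> 5
#   l = BatchRenorm('bn1', l, rmax, dmax)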
@@ -8,7 +8,7 @@ from contextlib import contextmanager
from .gradproc import apply_grad_processors as apply_gradproc
__all__ = ['apply_grad_processors', 'ProxyOptimizer',
-           'PostProcessVariablesOptimizer']
+           'PostProcessOptimizer', 'VariableAssignmentOptimizer']
class ProxyOptimizer(tf.train.Optimizer):
@@ -56,10 +56,10 @@ def apply_grad_processors(opt, gradprocs):
    return _ApplyGradientProcessor(opt, gradprocs)
-class PostProcessVariablesOptimizer(ProxyOptimizer):
+class PostProcessOptimizer(ProxyOptimizer):
    """
-    An optimizer which applies an operation to variables
-    (e.g. clipping, quantization) after updating the gradient.
+    An optimizer which applies some "post-processing operation" per variable
+    (e.g. clipping, quantization) after the gradient update.
    """
    def __init__(self, opt, func, colocate=True):
        """
@@ -69,12 +69,12 @@ class PostProcessVariablesOptimizer(ProxyOptimizer):
                to perform for this variable after the gradient update.
            colocate (boolean): colocate the function with the variable.
        """
-        super(PostProcessVariablesOptimizer, self).__init__(opt)
+        super(PostProcessOptimizer, self).__init__(opt)
        self._func = func
        self._colocate = colocate
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-        update_op = super(PostProcessVariablesOptimizer, self).apply_gradients(
+        update_op = super(PostProcessOptimizer, self).apply_gradients(
            grads_and_vars, global_step)
        ops = []
        with tf.control_dependencies([update_op]):
@@ -95,3 +95,23 @@ class PostProcessVariablesOptimizer(ProxyOptimizer):
            yield
        else:
            yield
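# --- Editorial sketch (not part of the commit) ---
# Hypothetical usage of PostProcessOptimizer: run an extra op per weight matrix
# after each update, e.g. rounding it to a fixed-point grid (a crude form of
# quantization). The predicate, grid size, and learning rate are made up:
#
#   def quantize(v):
#       if v.op.name.endswith('/W'):
#           return tf.assign(v, tf.round(v * 128.0) / 128.0).op
#       return None   # None: nothing to do for this variable
#
#   opt = PostProcessOptimizer(tf.train.AdamOptimizer(3e-4), quantize)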
class VariableAssignmentOptimizer(PostProcessOptimizer):
"""
An optimizer which assigns each variable a new value (e.g. clipping,
quantization) after the gradient update.
"""
def __init__(self, opt, func):
"""
Args:
opt (tf.train.Optimizer):
func (tf.Variable -> tf.Tensor or None): the new value to be
assigned to this variable after the gradient update.
"""
def f(v):
t = func(v)
if t is None:
return t
return tf.assign(v, t, use_locking=False).op
super(VariableAssignmentOptimizer, self).__init__(opt, f)
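# --- Editorial sketch (not part of the commit) ---
# Hypothetical usage of VariableAssignmentOptimizer: clip weight matrices back
# into [-0.01, 0.01] after every gradient update (a WGAN-style weight clip) and
# leave all other variables alone by returning None. Names here are made up:
#
#   def clip_weights(v):
#       if v.name.endswith('/W:0'):
#           return tf.clip_by_value(v, -0.01, 0.01)
#       return None   # None: no assignment for this variable
#
#   opt = VariableAssignmentOptimizer(tf.train.RMSPropOptimizer(1e-4), clip_weights)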