Commit f6ede612 authored by Yuxin Wu

Better BatchNorm (with ema_update option decoupled from training)

parent 4a46b93d
@@ -169,8 +169,8 @@ class SeparateGANTrainer(TowerTrainer):
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
with TowerContext('', is_training=True), \
argscope(BatchNorm, internal_update=True):
# should not hook the updates to both train_op, it will hurt training speed.
argscope(BatchNorm, ema_update='internal'):
# should not hook the EMA updates to both train_ops; it will hurt training speed.
self.tower_func(*input.get_input_tensors())
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if len(update_ops):
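For context, a minimal usage sketch (not part of this commit; imports and names are assumptions based on tensorpack's public API) of applying the new `ema_update='internal'` option to a whole sub-graph via `argscope`, as the trainer above does:

    import tensorflow as tf
    from tensorpack import argscope
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    def build_net(x):
        # every BatchNorm created under this argscope updates its EMA
        # inside the layer instead of through tf.GraphKeys.UPDATE_OPS
        with argscope(BatchNorm, ema_update='internal'):
            return BatchNorm('bn0', x)

    with TowerContext('', is_training=True):
        image = tf.placeholder(tf.float32, [None, 32, 32, 3])
        out = build_net(image)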
@@ -12,6 +12,7 @@ from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import get_data_format
from ..utils.develop import log_deprecated
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
@@ -39,8 +40,8 @@ def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var,
moving_mean, moving_var, decay):
def internal_update_bn_ema(xn, batch_mean, batch_var,
moving_mean, moving_var, decay):
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
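For reference, a minimal sketch (not from this commit) of the arithmetic that `assign_moving_average` performs here with `zero_debias=False`, which both the renamed `internal_update_bn_ema` and the collection-based path rely on:

    # moving statistics are blended toward the current batch statistics:
    #   moving_new = decay * moving_old + (1 - decay) * batch_value
    def ema_step(moving_old, batch_value, decay=0.9):
        return decay * moving_old + (1.0 - decay) * batch_value

    # e.g. a moving mean of 0.0 pulled toward a batch mean of 1.0
    print(ema_step(0.0, 1.0))   # 0.1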
@@ -71,8 +72,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
gamma_initializer=tf.ones_initializer(),
virtual_batch_size=None,
data_format='channels_last',
internal_update=False,
sync_statistics=None):
ema_update='default',
sync_statistics=None,
internal_update=None):
"""
Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
in the following:
@@ -80,21 +82,29 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
2. Default value for `momentum` and `epsilon` is different.
3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
4. Support the ``internal_update`` option, which cover more use cases than the standard collection-based update.
5. Support the ``sync_statistics`` option, which is very useful in small-batch models.
4. Support the ``ema_update`` option, which covers more use cases than the standard EMA update.
5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
Args:
internal_update (bool): if False, add EMA update ops to
`tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
They are very similar in speed, but `internal_update=True` is recommended and can be helpful when:
1. BatchNorm is used inside dynamic control flow.
The collection-based update does not support dynamic control flows.
2. BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternatively).
Putting all update ops into a single collection will waste a lot of compute.
Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
sync_statistics (str or None): one of None, "nccl", or "horovod".
training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
to normalize. By default, it is equal to `get_current_tower_context().is_training`.
This is not a good argument name, but it is what the Tensorflow layer uses.
ema_update (str): Only effective when ``training=True``. It has the following options:
* "default": same as "collection". Because this is the default behavior in tensorflow.
* "skip": do not update EMA.
* "collection": Add EMA update ops to collection `tf.GraphKeys.UPDATE_OPS`.
The ops in the collection will be run automatically by the callback :class:`RunUpdateOps`.
* "internal": EMA is updated inside this layer itself by control dependencies.
It has similar speed to "collection", but "internal" is recommended and can be helpful when:
1. BatchNorm is used inside dynamic control flow.
The collection-based update does not support dynamic control flows.
2. The BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternately).
Putting all update ops into a single collection will waste a lot of compute.
Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
sync_statistics (str or None): one of None, "nccl", or "horovod". It determines how to compute the
"per-batch statistics" when ``training==True``.
By default (None), it uses statistics of the input tensor to normalize during training.
This is the standard way BatchNorm was implemented in most frameworks.
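As an illustration of these arguments (a hedged sketch, not part of the diff; it assumes the signature documented above), the three update modes would be selected per layer like this:

    import tensorflow as tf
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('', is_training=True):
        x = tf.placeholder(tf.float32, [None, 16, 16, 32])
        # default: EMA update ops are added to tf.GraphKeys.UPDATE_OPS
        y1 = BatchNorm('bn1', x)
        # EMA updated inside the layer via control dependencies
        y2 = BatchNorm('bn2', x, ema_update='internal')
        # normalize with batch statistics but never touch the stored EMA
        y3 = BatchNorm('bn3', x, ema_update='skip')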
@@ -119,15 +129,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
If different GPUs execute one BatchNorm layer for different number of times
(e.g., if some GPUs do not execute it), this layer may hang.
This option only takes effect when `training == get_current_tower_context().is_training == True`.
This option is also known as "Cross-GPU BatchNorm" as mentioned in:
This option is also known as "SyncBN" or "Cross-GPU BatchNorm" as mentioned in:
`MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.
When `sync_statistics` is enabled, `internal_update` will be set to True automatically.
When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
This is to avoid running `UPDATE_OPS`, which requires synchronization.
internal_update: deprecated option. Don't use.
Variable Names:
* ``beta``: the bias term. Will be zero-inited by default.
@@ -136,16 +146,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
* ``variance/EMA``: the moving average of variance.
Note:
Combinations of ``training`` and ``ctx.is_training``:
* ``training == ctx.is_training``: standard BN, EMA are maintained during training
and used during inference. This is the default.
* ``training and not ctx.is_training``: still use batch statistics in inference.
* ``not training and ctx.is_training``: use EMA to normalize in
training. This is useful when you load a pre-trained BN and
don't want to fine tune the EMA. EMA will not be updated in
this case.
This layer is more flexible than the standard "BatchNorm" layer and provides more features:
1. Regardless of whether you are training or not, you can set the `training` argument
to choose between batch statistics and EMA statistics,
i.e., you can use batch statistics during inference, or EMA statistics during training.
Using EMA statistics in training is useful when you load a pre-trained BN and
don't want to update it.
2. As long as `training=True`, the `sync_statistics` and `ema_update` options will take effect.
"""
ctx = get_current_tower_context()
# parse shapes
data_format = get_data_format(data_format, keras_mode=False)
shape = inputs.get_shape().as_list()
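A hedged sketch (not in the diff) of the "EMA statistics during training" case described in the note above, i.e. fine-tuning a pre-trained network with frozen BN statistics:

    import tensorflow as tf
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    # Inside a training tower, force the layer to normalize with the stored
    # EMA instead of batch statistics; the EMA will then not be updated.
    with TowerContext('', is_training=True):
        x = tf.placeholder(tf.float32, [None, 16, 16, 32])
        y = BatchNorm('frozen_bn', x, training=False)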
@@ -155,6 +164,23 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
sync_statistics = sync_statistics.lower()
assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
assert ema_update in ["default", "collection", "internal", "skip"]
if internal_update is not None:
log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
assert ema_update == 'default', \
"Do not use internal_update and ema_update together! internal_update is deprecated"
ema_update = "internal" if internal_update else "collection"
if ema_update == "default":
ema_update = "collection"
# Logic:
# 1. EMA update is possible only when we compute batch statistics (training=True)
# 2. We know that in training, the non-main training towers do not need the EMA update.
# We don't know what to do in a prediction context, so be conservative and do the update.
# 3. The user can explicitly disable the update with "skip".
do_ema_update = training and \
(ctx.is_main_training_tower or not ctx.is_training) \
and (ema_update != "skip")
if axis is None:
if ndims == 2:
axis = 1
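To make the logic above concrete, a standalone restatement (a sketch, not code from the commit) of when the EMA gets updated:

    def should_update_ema(training, is_main_training_tower, is_training, ema_update):
        # batch statistics must be computed (training=True); in a training
        # context only the main tower updates the EMA; in a prediction context
        # be conservative and update; "skip" always disables the update.
        return training and \
            (is_main_training_tower or not is_training) and \
            ema_update != "skip"

    assert should_update_ema(True, True, True, "collection")
    assert not should_update_ema(True, False, True, "internal")    # non-main tower
    assert not should_update_ema(False, True, True, "collection")  # EMA used, not updated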
@@ -163,12 +189,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
assert axis in [1, 3], axis
num_chan = shape[axis]
TF_version = get_tf_version_tuple()
# parse training/ctx
ctx = get_current_tower_context()
if training is None:
training = ctx.is_training
training = bool(training)
TF_version = get_tf_version_tuple()
freeze_bn_backward = not training and ctx.is_training
if freeze_bn_backward:
assert TF_version >= (1, 4), \
@@ -177,12 +203,14 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
do_sync_bn = (sync_statistics is not None) and training
if sync_statistics is None or not (training and ctx.is_training):
if not do_sync_bn:
# Use the builtin layer for anything except for sync-bn
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
with rename_get_variable(
{'moving_mean': 'mean/EMA',
'moving_variance': 'variance/EMA'}):
'moving_variance': 'variance/EMA'}):
tf_args = dict(
axis=axis,
momentum=momentum, epsilon=epsilon,
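The backup/restore pattern above effectively discards whatever the builtin layer appends to UPDATE_OPS when no EMA update is wanted. A self-contained sketch of that idea (assuming tensorpack's `backup_collection`/`restore_collection`, as used above in this file):

    import tensorflow as tf
    from tensorpack.tfutils.collection import backup_collection, restore_collection

    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])

    # ... a layer built here may add EMA update ops to UPDATE_OPS ...
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.no_op(name='fake_ema_update'))

    # restoring the backup drops the ops added in between, so the callback
    # that runs UPDATE_OPS will not execute them
    restore_collection(coll_bk)
    assert len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) == 0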
@@ -204,16 +232,17 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
layer = tf.layers.BatchNormalization(**tf_args)
xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
# maintain EMA only on one GPU is OK, even in replicated mode.
# because during training, EMA isn't used
# Add EMA variables to the correct collection
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk)
if training and internal_update:
if not do_ema_update:
restore_collection(coll_bk)
if do_ema_update and ema_update == "internal":
# Implement "internal" update.
restore_collection(coll_bk)
assert layer.updates
with tf.control_dependencies(layer.updates):
ret = tf.identity(xn, name='output')
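The "internal" branch relies on the standard control-dependency idiom; a minimal self-contained sketch (with plain `tf.assign` ops standing in for `layer.updates`):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 32])
    moving_mean = tf.get_variable('mean_EMA', [32], trainable=False,
                                  initializer=tf.zeros_initializer())
    batch_mean = tf.reduce_mean(x, axis=0)
    update_op = tf.assign(moving_mean, 0.9 * moving_mean + 0.1 * batch_mean)

    # anything that consumes `out` now also runs the EMA update,
    # without going through tf.GraphKeys.UPDATE_OPS
    with tf.control_dependencies([update_op]):
        out = tf.identity(x, name='output')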
@@ -301,8 +330,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
inputs, batch_mean, batch_var,
beta, gamma, epsilon)
if ctx.is_main_training_tower:
ret = update_bn_ema(
if do_ema_update:
ret = internal_update_bn_ema(
xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var, momentum)
else:
ret = tf.identity(xn, name='output')
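Finally, a hedged usage sketch (not from the diff) of enabling the SyncBN path touched above; in practice this is meant to run inside a multi-GPU data-parallel trainer:

    import tensorflow as tf
    from tensorpack import argscope
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('', is_training=True), \
            argscope(BatchNorm, sync_statistics='nccl'):
        x = tf.placeholder(tf.float32, [None, 16, 16, 32])
        # per this commit, ema_update is forced to 'internal' on this path,
        # so nothing is added to UPDATE_OPS
        y = BatchNorm('sync_bn', x)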