Commit f6ede612 authored by Yuxin Wu

Better BatchNorm (with ema_update option decoupled from training)

parent 4a46b93d
@@ -169,8 +169,8 @@ class SeparateGANTrainer(TowerTrainer):
         # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True), \
-                argscope(BatchNorm, internal_update=True):
-            # should not hook the updates to both train_op, it will hurt training speed.
+                argscope(BatchNorm, ema_update='internal'):
+            # should not hook the EMA updates to both train_op, it will hurt training speed.
             self.tower_func(*input.get_input_tensors())
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         if len(update_ops):
...
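The hunk above switches the GAN trainer from the boolean internal_update flag to the new string option. A minimal sketch of the same pattern in user code, assuming the usual tensorpack top-level imports (the build_tower helper here is hypothetical):

    from tensorpack import BatchNorm, TowerContext, argscope

    def build_tower(build_graph_fn, input_tensors):
        # 'internal' makes every BatchNorm update its EMA through control
        # dependencies, so nothing is appended to tf.GraphKeys.UPDATE_OPS.
        with TowerContext('', is_training=True), \
                argscope(BatchNorm, ema_update='internal'):
            return build_graph_fn(*input_tensors)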
@@ -12,6 +12,7 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import get_data_format
+from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
@@ -39,8 +40,8 @@ def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
     return beta, gamma, moving_mean, moving_var


-def update_bn_ema(xn, batch_mean, batch_var,
-                  moving_mean, moving_var, decay):
+def internal_update_bn_ema(xn, batch_mean, batch_var,
+                           moving_mean, moving_var, decay):
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
         name='mean_ema_op')
@@ -71,8 +72,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
               gamma_initializer=tf.ones_initializer(),
               virtual_batch_size=None,
               data_format='channels_last',
-              internal_update=False,
-              sync_statistics=None):
+              ema_update='default',
+              sync_statistics=None,
+              internal_update=None):
     """
     Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
     in the following:
@@ -80,21 +82,29 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
     2. Default value for `momentum` and `epsilon` is different.
     3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
-    4. Support the ``internal_update`` option, which cover more use cases than the standard collection-based update.
-    5. Support the ``sync_statistics`` option, which is very useful in small-batch models.
+    4. Support the ``ema_update`` option, which covers more use cases than the standard EMA update.
+    5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.

     Args:
-        internal_update (bool): if False, add EMA update ops to
-            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
-            They are very similar in speed, but `internal_update=True` is recommended and can be helpful when:
-
-            1. BatchNorm is used inside dynamic control flow.
-               The collection-based update does not support dynamic control flows.
-            2. BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternatively).
-               Putting all update ops into a single collection will waste a lot of compute.
-
-            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
-        sync_statistics (str or None): one of None, "nccl", or "horovod".
+        training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
+            to normalize. By default, it is equal to `get_current_tower_context().is_training`.
+            This is not a good argument name, but it is what the Tensorflow layer uses.
+        ema_update (str): Only effective when ``training=True``. It has the following options:
+
+            * "default": same as "collection", because this is the default behavior in tensorflow.
+            * "skip": do not update EMA.
+            * "collection": Add EMA update ops to collection `tf.GraphKeys.UPDATE_OPS`.
+              The ops in the collection will be run automatically by the callback :class:`RunUpdateOps`.
+            * "internal": EMA is updated inside this layer itself by control dependencies.
+              It has similar speed to "collection", but "internal" is recommended and can be helpful when:
+
+              1. BatchNorm is used inside dynamic control flow.
+                 The collection-based update does not support dynamic control flows.
+              2. BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternatively).
+                 Putting all update ops into a single collection will waste a lot of compute.
+
+              Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
+        sync_statistics (str or None): one of None, "nccl", or "horovod". It determines how to compute the
+            "per-batch statistics" when ``training==True``.

             By default (None), it uses statistics of the input tensor to normalize during training.
             This is the standard way BatchNorm was implemented in most frameworks.
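The Args section above introduces the four ema_update modes. A hedged usage sketch (the 'bn*' layer names and the feature_map tensor are placeholders, not part of the commit):

    from tensorpack import BatchNorm

    x = BatchNorm('bn1', feature_map)                         # "default" behaves like "collection": EMA ops go to UPDATE_OPS
    x = BatchNorm('bn2', feature_map, ema_update='internal')  # EMA updated inside the layer via control dependencies
    x = BatchNorm('bn3', feature_map, ema_update='skip')      # normalize with batch statistics, never touch the EMA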
@@ -119,15 +129,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             If different GPUs execute one BatchNorm layer for different number of times
             (e.g., if some GPUs do not execute it), this layer may hang.

-            This option only has effect when `training == get_current_tower_context().training == True`.
-            This option is also known as "Cross-GPU BatchNorm" as mentioned in:
+            This option is also known as "SyncBN" or "Cross-GPU BatchNorm" as mentioned in:
             `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.

             Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

-            When `sync_statistics` is enabled, `internal_update` will be set to True automatically.
+            When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
             This is to avoid running `UPDATE_OPS`, which requires synchronization.
+        internal_update: deprecated option. Don't use.

     Variable Names:

     * ``beta``: the bias term. Will be zero-inited by default.
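For the sync_statistics option documented above, a small sketch of enabling "SyncBN" for a whole model (the backbone function is hypothetical; with sync_statistics set, ema_update is forced to "internal" as stated in the docstring):

    from tensorpack import BatchNorm, argscope

    def backbone(image):
        # Cross-GPU statistics via NCCL; use 'horovod' when training with horovod.
        with argscope(BatchNorm, sync_statistics='nccl'):
            # ... Conv2D / BatchNorm / activation blocks go here ...
            return image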
@@ -136,16 +146,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     * ``variance/EMA``: the moving average of variance.

     Note:
-        Combinations of ``training`` and ``ctx.is_training``:
-
-        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
-          and used during inference. This is the default.
-        * ``training and not ctx.is_training``: still use batch statistics in inference.
-        * ``not training and ctx.is_training``: use EMA to normalize in
-          training. This is useful when you load a pre-trained BN and
-          don't want to fine tune the EMA. EMA will not be updated in
-          this case.
+        This layer is more flexible than the standard "BatchNorm" layer and provides more features:
+
+        1. No matter whether you're doing training or not, you can set the `training` argument
+           to use batch statistics / EMA statistics,
+           i.e., you can use batch statistics during inference, or use EMA statistics during training.
+           Using EMA statistics in training is useful when you load a pre-trained BN and
+           don't want to update it.
+        2. As long as `training=True`, the `sync_statistics` and `ema_update` options will take effect.
     """
+    ctx = get_current_tower_context()
     # parse shapes
     data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
@@ -155,6 +164,23 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         sync_statistics = sync_statistics.lower()
     assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

+    assert ema_update in ["default", "collection", "internal", "skip"]
+    if internal_update is not None:
+        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
+        assert ema_update == 'default', \
+            "Do not use internal_update and ema_update together! internal_update is deprecated"
+        ema_update = "internal" if internal_update else "collection"
+    if ema_update == "default":
+        ema_update = "collection"
+
+    # Logic:
+    # 1. EMA update is possible only when we compute batch statistics (training=True)
+    # 2. We know that in training, the non-main training tower does not need the EMA update.
+    #    We don't know what to do in a prediction context, so be conservative and do the update.
+    # 3. The user can explicitly disable the update with "skip".
+    do_ema_update = training and \
+        (ctx.is_main_training_tower or not ctx.is_training) \
+        and (ema_update != "skip")
+
     if axis is None:
         if ndims == 2:
             axis = 1
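The new parsing logic above maps the deprecated internal_update flag onto ema_update and then decides whether the EMA update should run at all. A standalone restatement of that decision, as a sketch (the helper names are hypothetical):

    def resolve_ema_update(ema_update, internal_update):
        # The deprecated boolean flag maps onto the new string option.
        if internal_update is not None:
            assert ema_update == 'default', "do not combine internal_update with ema_update"
            return 'internal' if internal_update else 'collection'
        return 'collection' if ema_update == 'default' else ema_update

    def should_update_ema(training, is_main_training_tower, is_training, ema_update):
        # EMA update needs batch statistics (training=True); non-main training
        # towers skip it; "skip" disables it explicitly.
        return training and (is_main_training_tower or not is_training) and ema_update != 'skip'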
@@ -163,12 +189,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         assert axis in [1, 3], axis
     num_chan = shape[axis]

+    TF_version = get_tf_version_tuple()
+
     # parse training/ctx
-    ctx = get_current_tower_context()
     if training is None:
         training = ctx.is_training
     training = bool(training)
-    TF_version = get_tf_version_tuple()
     freeze_bn_backward = not training and ctx.is_training
     if freeze_bn_backward:
         assert TF_version >= (1, 4), \
@@ -177,12 +203,14 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
         # Using moving_mean/moving_variance in training, which means we
         # loaded a pre-trained BN and only fine-tuning the affine part.

-    if sync_statistics is None or not (training and ctx.is_training):
+    do_sync_bn = (sync_statistics is not None) and training
+    if not do_sync_bn:
+        # Use the builtin layer for anything except for sync-bn
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable(
                 {'moving_mean': 'mean/EMA',
                  'moving_variance': 'variance/EMA'}):
             tf_args = dict(
                 axis=axis,
                 momentum=momentum, epsilon=epsilon,
@@ -204,16 +232,17 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             layer = tf.layers.BatchNormalization(**tf_args)
             xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

-        # maintain EMA only on one GPU is OK, even in replicated mode.
-        # because during training, EMA isn't used
+        # Add EMA variables to the correct collection
         if ctx.is_main_training_tower:
             for v in layer.non_trainable_variables:
                 if isinstance(v, tf.Variable):
                     tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
-        if not ctx.is_main_training_tower or internal_update:
-            restore_collection(coll_bk)

-        if training and internal_update:
+        if not do_ema_update:
+            restore_collection(coll_bk)
+        if do_ema_update and ema_update == "internal":
+            # Implement "internal" update.
+            restore_collection(coll_bk)
             assert layer.updates
             with tf.control_dependencies(layer.updates):
                 ret = tf.identity(xn, name='output')
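The hunk above is where ema_update='internal' is implemented for the built-in layer path: the EMA assign ops produced by the layer become control dependencies of the output tensor, so fetching the output also runs the update. A minimal illustration of that idea (attach_internal_update is a hypothetical helper, not part of the commit):

    import tensorflow as tf

    def attach_internal_update(xn, layer):
        assert layer.updates, "the layer must have produced EMA update ops"
        with tf.control_dependencies(layer.updates):
            return tf.identity(xn, name='output')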
@@ -301,8 +330,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             inputs, batch_mean, batch_var,
             beta, gamma, epsilon)

-        if ctx.is_main_training_tower:
-            ret = update_bn_ema(
+        if do_ema_update:
+            ret = internal_update_bn_ema(
                 xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var, momentum)
         else:
             ret = tf.identity(xn, name='output')
...