Commit 07783edb authored by Yuxin Wu

Sync BatchNorm statistics with nccl or horovod

parent bffcfc1b
@@ -36,7 +36,6 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
 def update_bn_ema(xn, batch_mean, batch_var,
                   moving_mean, moving_var, decay, internal_update):
-    # TODO is there a way to use zero_debias in multi-GPU?
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
         name='mean_ema_op')
@@ -147,7 +146,6 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             mean=moving_mean, variance=moving_var, epsilon=epsilon,
             data_format=data_format, is_training=False)
     else:
-        # avoid the reshape if possible (when channel is the last dimension)
         xn = tf.nn.batch_normalization(
             inputs, moving_mean, moving_var, beta, gamma, epsilon)
...
@@ -4,6 +4,9 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
+from tensorflow.python.training import moving_averages
+import re
+import six

 from ..utils import logger
 from ..utils.argtools import get_data_format
@@ -19,6 +22,42 @@ __all__ = ['BatchNorm', 'BatchRenorm']

 # eps: torch: 1e-5. Lasagne: 1e-4
+def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
+    if use_bias:
+        beta = tf.get_variable('beta', [n_out], initializer=beta_init)
+    else:
+        beta = tf.zeros([n_out], name='beta')
+    if use_scale:
+        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
+    else:
+        gamma = tf.ones([n_out], name='gamma')
+    # x * gamma + beta
+    moving_mean = tf.get_variable('mean/EMA', [n_out],
+                                  initializer=tf.constant_initializer(), trainable=False)
+    moving_var = tf.get_variable('variance/EMA', [n_out],
+                                 initializer=tf.constant_initializer(1.0), trainable=False)
+    return beta, gamma, moving_mean, moving_var
+
+
+def update_bn_ema(xn, batch_mean, batch_var,
+                  moving_mean, moving_var, decay, internal_update):
+    update_op1 = moving_averages.assign_moving_average(
+        moving_mean, batch_mean, decay, zero_debias=False,
+        name='mean_ema_op')
+    update_op2 = moving_averages.assign_moving_average(
+        moving_var, batch_var, decay, zero_debias=False,
+        name='var_ema_op')
+    if internal_update:
+        with tf.control_dependencies([update_op1, update_op2]):
+            return tf.identity(xn, name='output')
+    else:
+        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
+        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
+        return tf.identity(xn, name='output')
+
+
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
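The update_bn_ema helper added above implements the two EMA strategies selected by `internal_update`: with `internal_update=True` the two assign ops become control dependencies of the layer output, while with `internal_update=False` they are appended to `tf.GraphKeys.UPDATE_OPS` and have to be run by the training step. A minimal sketch of the latter case in a plain TF 1.x training loop (the `loss`, learning rate and optimizer below are placeholders, not part of this commit):

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # contains mean_ema_op / var_ema_op
    with tf.control_dependencies(update_ops):
        # run the EMA updates together with every optimization step
        train_op = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9).minimize(loss)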
@@ -35,20 +74,30 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
               gamma_initializer=tf.ones_initializer(),
               virtual_batch_size=None,
               data_format='channels_last',
-              internal_update=False):
+              internal_update=False,
+              sync_statistics=None):
     """
-    Mostly equivalent to `tf.layers.batch_normalization`, but different in
-    the following due to historical reasons:
+    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
+    in the following:

-    1. Accepts `data_format` when `axis` is None. For 2D input, this argument will be ignored.
+    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
     2. Default value for `momentum` and `epsilon` is different.
-    3. Default value for `training` is automatically obtained from `TowerContext`.
-    4. Support the `internal_update` option, which can be very useful in certain models.
+    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
+    4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals.
+    5. Support the `sync_statistics` option, which is very useful in small-batch models.

     Args:
         internal_update (bool): if False, add EMA update ops to
           `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
           by control dependencies.
+          They are very similar in speed, but `internal_update=True` can be used
+          when you have conditionals in your model, or when you have multiple networks to train.
+        sync_statistics: either None or "nccl". By default (None), it uses statistics of the input tensor to normalize.
+          When set to "nccl", this layer must be used under tensorpack multi-gpu trainers,
+          and it then uses per-machine (multiple GPU) statistics to normalize.
+
+          This option has no effect when not training.
+
+          The option is also known as "Cross-GPU BatchNorm" as mentioned in https://arxiv.org/abs/1711.07240.

     Variable Names:
@@ -58,9 +107,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         * ``variance/EMA``: the moving average of variance.

     Note:
-        1. About multi-GPU training: moving averages across GPUs are not aggregated.
-           Batch statistics are computed independently. This is consistent with most frameworks.
-        2. Combinations of ``training`` and ``ctx.is_training``:
+        1. Combinations of ``training`` and ``ctx.is_training``:
             * ``training == ctx.is_training``: standard BN, EMA are
               maintained during training and used during inference. This is
               the default.
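To make the new `sync_statistics` option concrete, a usage sketch under a tensorpack multi-GPU trainer; the helper function, tensor names and the import path `tensorpack.models` are illustrative assumptions, only `BatchNorm` and its `sync_statistics` argument come from this commit:

    import tensorflow as tf
    from tensorpack.models import BatchNorm

    def bn_relu(x, name):
        # Normalize with batch mean/variance aggregated across all GPUs of this
        # machine (requires a tensorpack multi-GPU trainer); sync_statistics=None
        # would fall back to the usual per-GPU statistics.
        x = BatchNorm(name + '_bn', x, sync_statistics='nccl')
        return tf.nn.relu(x, name=name + '_relu')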
@@ -75,6 +122,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4], ndims
+    if sync_statistics is not None:
+        sync_statistics = sync_statistics.lower()
+    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

     if axis is None:
         if ndims == 2:
@@ -82,6 +132,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             axis = 1
         else:
             axis = 1 if data_format == 'NCHW' else 3
+    else:
+        data_format = 'NCHW' if axis == 1 else 'NHWC'
+    num_chan = shape[axis]

     # parse training/ctx
     ctx = get_current_tower_context()
@@ -98,36 +151,28 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         # Using moving_mean/moving_variance in training, which means we
         # loaded a pre-trained BN and only fine-tuning the affine part.

+    if sync_statistics is None or not (training and ctx.is_training):
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable(
                 {'moving_mean': 'mean/EMA',
                  'moving_variance': 'variance/EMA'}):
-            if TF_version >= 1.5:
-                layer = tf.layers.BatchNormalization(
+            tf_args = dict(
                 axis=axis,
                 momentum=momentum, epsilon=epsilon,
                 center=center, scale=scale,
                 beta_initializer=beta_initializer,
                 gamma_initializer=gamma_initializer,
-                virtual_batch_size=virtual_batch_size,
                 fused=True,
-                _reuse=tf.get_variable_scope().reuse
-            )
+                _reuse=tf.get_variable_scope().reuse)
+            if TF_version >= 1.5:
+                tf_args['virtual_batch_size'] = virtual_batch_size
             else:
                 assert virtual_batch_size is None, "Feature not supported in this version of TF!"
-                layer = tf.layers.BatchNormalization(
-                    axis=axis,
-                    momentum=momentum, epsilon=epsilon,
-                    center=center, scale=scale,
-                    beta_initializer=beta_initializer,
-                    gamma_initializer=gamma_initializer,
-                    fused=True,
-                    _reuse=tf.get_variable_scope().reuse
-                )
+            layer = tf.layers.BatchNormalization(**tf_args)
             xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

             # maintain EMA only on one GPU is OK, even in replicated mode.
-            # because training time doesn't use EMA
+            # because during training, EMA isn't used
             if ctx.is_main_training_tower:
                 for v in layer.non_trainable_variables:
                     add_model_variable(v)
@@ -150,6 +195,72 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             vh.gamma = layer.gamma
         if center:
             vh.beta = layer.beta
+    else:
+        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
+        new_shape = None  # don't need to reshape unless ...
+        if ndims == 4 and axis == 1:
+            new_shape = [1, num_chan, 1, 1]
+
+        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
+        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
+        if sync_statistics == 'nccl':
+            if six.PY3 and TF_version <= 1.8 and ctx.is_main_training_tower:
+                logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
+                            "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360")
+            from tensorflow.contrib.nccl.ops import gen_nccl_ops
+            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
+            num_dev = ctx.total
+            batch_mean = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
+            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean_square,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
+        elif sync_statistics == 'horovod':
+            # Require https://github.com/uber/horovod/pull/331
+            # Proof-of-concept, not ready yet.
+            import horovod.tensorflow as hvd
+            batch_mean = hvd.allreduce(batch_mean, average=True)
+            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
+        batch_var = batch_mean_square - tf.square(batch_mean)
+        batch_mean_vec = batch_mean
+        batch_var_vec = batch_var
+
+        beta, gamma, moving_mean, moving_var = get_bn_variables(
+            num_chan, scale, center, beta_initializer, gamma_initializer)
+
+        if new_shape is not None:
+            batch_mean = tf.reshape(batch_mean, new_shape)
+            batch_var = tf.reshape(batch_var, new_shape)
+            r_gamma = tf.reshape(gamma, new_shape)
+            r_beta = tf.reshape(beta, new_shape)
+        else:
+            r_gamma, r_beta = gamma, beta
+        xn = tf.nn.batch_normalization(
+            inputs, batch_mean, batch_var, r_beta, r_gamma, epsilon)
+
+        if ctx.is_main_training_tower:
+            ret = update_bn_ema(
+                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
+                momentum, internal_update)
+        else:
+            ret = tf.identity(xn, name='output')
+
+        vh = ret.variables = VariableHolder(
+            moving_mean=moving_mean,
+            mean=moving_mean,  # for backward-compatibility
+            moving_variance=moving_var,
+            variance=moving_var)  # for backward-compatibility
+        if scale:
+            vh.gamma = gamma
+        if center:
+            vh.beta = beta
     return ret
...
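The sync branch added above all-reduces only two vectors per GPU, E[x] and E[x^2], and recovers the variance afterwards via Var[x] = E[x^2] - E[x]^2. A small self-contained NumPy check of that identity for equally sized per-GPU sub-batches (the array shapes below are made up for illustration):

    import numpy as np

    # Two equally sized sub-batches, standing in for the per-GPU inputs.
    x0 = np.random.randn(32, 64).astype(np.float32)
    x1 = np.random.randn(32, 64).astype(np.float32)

    # What each GPU reduces locally, then all-reduces with averaging:
    mean = (x0.mean(axis=0) + x1.mean(axis=0)) / 2                   # E[x]
    mean_sq = ((x0 ** 2).mean(axis=0) + (x1 ** 2).mean(axis=0)) / 2  # E[x^2]
    var = mean_sq - mean ** 2                                        # Var[x] = E[x^2] - E[x]^2

    # Matches the statistics of the combined (cross-GPU) batch.
    full = np.concatenate([x0, x1], axis=0)
    assert np.allclose(mean, full.mean(axis=0), atol=1e-5)
    assert np.allclose(var, full.var(axis=0), atol=1e-5)

Averaging the per-GPU moments is only exact when every GPU sees the same sub-batch size, which is the setting the nccl/horovod all-reduce in this commit assumes.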