Commit 65d0f0b9 authored by Yuxin Wu

use fused_batch_norm when use_local_stat=False&is_training

parent 81f4b575
@@ -94,12 +94,17 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
* ``variance/EMA``: the moving average of variance.
Note:
In multi-GPU training, moving averages across GPUs are not aggregated.
This is consistent with most frameworks.
However, all GPUs use the moving averages on the first GPU (instead of
their own); this is inconsistent with most frameworks (but consistent
with the official inceptionv3 example).
1. About multi-GPU training: moving averages across GPUs are not aggregated.
Batch statistics are computed independently. This is consistent with most frameworks.
2. Combinations of ``use_local_stat`` and ``ctx.is_training``:
* ``use_local_stat == is_training``: standard BN; EMAs are
maintained during training and used during inference.
* ``use_local_stat and not is_training``: still use local (batch)
statistics in inference.
* ``not use_local_stat and is_training``: use EMA to normalize in
training. This is useful when you load a pre-trained BN and
don't want to fine-tune the EMA. EMA will not be updated in
this case.
"""
shape = x.get_shape().as_list()
ndims = len(shape)
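The four combinations documented above reduce to a small decision table. Below is an illustrative sketch of those semantics only; ``describe_bn_mode`` is a hypothetical helper, not part of tensorpack, and it assumes the default ``use_local_stat = is_training`` when the argument is left as None.

def describe_bn_mode(use_local_stat, is_training):
    # Hypothetical helper mirroring the docstring above; not actual layer code.
    if use_local_stat is None:
        use_local_stat = is_training                    # default: standard BN
    if use_local_stat and is_training:
        return "batch statistics; EMA updated"          # standard training
    if use_local_stat and not is_training:
        return "batch statistics used at inference; EMA untouched"
    if not use_local_stat and is_training:
        return "EMA used to normalize; EMA not updated"  # fine-tune affine part only
    return "EMA used to normalize"                      # standard inference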
@@ -131,16 +136,24 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
xn = tf.squeeze(xn, [1, 2])
else:
if ctx.is_training:
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# non-fused op is faster for inference
if ndims == 4 and data_format == 'NCHW':
[g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
for _ in [gamma, beta, moving_mean, moving_var]]
xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
if ctx.index == 0: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
xn, _, _ = tf.nn.fused_batch_norm(
x, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
# avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
# non-fused op is faster for inference # TODO test if this is still true
if ndims == 4 and data_format == 'NCHW':
[g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
for _ in [gamma, beta, moving_mean, moving_var]]
xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
else:
# avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
# Maintaining the EMA on only one GPU is OK, even in replicated mode,
# because training does not use the EMA.
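For reference, a minimal standalone sketch (assuming TF 1.x and NHWC input; shapes and values are made up for the example) of why the switch is safe: ``tf.nn.fused_batch_norm`` with ``is_training=False`` normalizes with the supplied moving statistics, just like ``tf.nn.batch_normalization``, so the two branches should agree up to floating-point tolerance.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 8, 8, 16).astype('float32'))        # NHWC input
gamma = tf.ones([16])
beta = tf.zeros([16])
moving_mean = tf.constant(np.random.rand(16).astype('float32'))
moving_var = tf.constant(np.random.rand(16).astype('float32') + 0.5)
eps = 1e-3

# fused path: normalize with the moving statistics (no update), as in this commit
y_fused, _, _ = tf.nn.fused_batch_norm(
    x, gamma, beta, mean=moving_mean, variance=moving_var,
    epsilon=eps, data_format='NHWC', is_training=False)

# non-fused path: the previous behavior for this branch
y_ref = tf.nn.batch_normalization(x, moving_mean, moving_var, beta, gamma, eps)

with tf.Session() as sess:
    a, b = sess.run([y_fused, y_ref])
    print(np.abs(a - b).max())    # expected to be tiny (float32 rounding only)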
@@ -44,7 +44,7 @@ def get_default_sess_config(mem_fraction=0.99):
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
# Hurt performance in 8xP100 training
# May hurt performance
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
return conf
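For context, a minimal sketch of a config built along the lines shown above (TF 1.x assumed; the mapping of ``mem_fraction`` to ``per_process_gpu_memory_fraction`` is an assumption, since that line is not part of this hunk). The XLA JIT option stays commented out, matching the note that it may hurt performance.

import tensorflow as tf

conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = 0.99   # assumed mapping of mem_fraction
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
# May hurt performance:
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1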