Commit 65d0f0b9 authored by Yuxin Wu

use fused_batch_norm when use_local_stat=False&is_training

parent 81f4b575
@@ -94,12 +94,17 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     * ``variance/EMA``: the moving average of variance.
 
     Note:
-        In multi-GPU training, moving averages across GPUs are not aggregated.
-        This is consistent with most frameworks.
-
-        However, all GPUs use the moving averages on the first GPU (instead of
-        their own), this is inconsistent with most frameworks (but consistent
-        with the official inceptionv3 example).
+        1. About multi-GPU training: moving averages across GPUs are not aggregated.
+           Batch statistics are computed independently. This is consistent with most frameworks.
+        2. Combinations of ``use_local_stat`` and ``ctx.is_training``:
+            * ``use_local_stat == is_training``: standard BN, EMA are
+              maintained during training and used during inference.
+            * ``use_local_stat and not is_training``: still use local (batch)
+              statistics in inference.
+            * ``not use_local_stat and is_training``: use EMA to normalize in
+              training. This is useful when you load a pre-trained BN and
+              don't want to fine tune the EMA. EMA will not be updated in
+              this case.
     """
     shape = x.get_shape().as_list()
     ndims = len(shape)
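For quick reference, the decision table described in note 2 of the docstring can be summarized as below. This is a hypothetical helper for illustration only; it is not part of tensorpack or of this commit.

```python
# Sketch of the docstring's decision table: which statistics normalize the
# input and whether the EMAs get updated. Hypothetical helper, for reference.
def bn_mode(use_local_stat, is_training):
    if use_local_stat and is_training:        # standard training-time BN
        return "batch statistics; EMAs updated"
    if use_local_stat and not is_training:    # batch statistics at inference
        return "batch statistics; EMAs unchanged"
    if not use_local_stat and is_training:    # frozen pre-trained BN (new branch below)
        return "EMA statistics; EMAs unchanged"
    return "EMA statistics; EMAs unchanged"   # standard inference
```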
@@ -131,8 +136,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
             xn = tf.squeeze(xn, [1, 2])
     else:
         if ctx.is_training:
-            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
-        # non-fused op is faster for inference
-        if ndims == 4 and data_format == 'NCHW':
-            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
-                              for _ in [gamma, beta, moving_mean, moving_var]]
+            if ctx.index == 0:  # only warn in first tower
+                logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
+            # Using moving_mean/moving_variance in training, which means we
+            # loaded a pre-trained BN and are only fine-tuning the affine part.
+            xn, _, _ = tf.nn.fused_batch_norm(
+                x, gamma, beta,
+                mean=moving_mean, variance=moving_var, epsilon=epsilon,
+                data_format=data_format, is_training=False)
+        else:
+            # non-fused op is faster for inference  # TODO test if this is still true
+            if ndims == 4 and data_format == 'NCHW':
+                [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                                  for _ in [gamma, beta, moving_mean, moving_var]]
...
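To show what the new ``is_training`` branch amounts to, here is a minimal standalone sketch (TF 1.x graph mode; the function name is made up and this is not the tensorpack implementation) of normalizing with frozen EMA statistics via ``tf.nn.fused_batch_norm``:

```python
import tensorflow as tf  # TF 1.x sketch, not part of this commit


def bn_with_frozen_ema(x, gamma, beta, moving_mean, moving_var,
                       epsilon=1e-5, data_format='NHWC'):
    """Normalize a 4-D tensor `x` with pre-computed EMA statistics.

    Passing `is_training=False` together with `mean`/`variance` makes
    fused_batch_norm use the supplied statistics instead of batch
    statistics, so the EMAs are neither used for updates nor changed.
    """
    xn, _, _ = tf.nn.fused_batch_norm(
        x, gamma, beta,
        mean=moving_mean, variance=moving_var,
        epsilon=epsilon, data_format=data_format, is_training=False)
    return xn
```

This mirrors the ``not use_local_stat and is_training`` case in the docstring: the affine parameters ``gamma``/``beta`` can still be fine-tuned while the normalization statistics stay fixed.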
@@ -44,7 +44,7 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
-    # Hurt performance in 8xP100 training
+    # May hurt performance
     # conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
     return conf
...
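For reference, the session config this function builds looks roughly like the sketch below. It is an approximation reconstructed from the hunk above, not the exact tensorpack code; the XLA JIT line stays commented out per the note on performance.

```python
import tensorflow as tf  # TF 1.x


def make_default_sess_config(mem_fraction=0.99):
    # Approximate reconstruction for illustration; see tensorpack's
    # get_default_sess_config for the real version.
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allocator_type = 'BFC'
    conf.gpu_options.allow_growth = True
    # XLA JIT is left disabled; the comment in the diff notes it may
    # hurt performance.
    # conf.graph_options.optimizer_options.global_jit_level = \
    #     tf.OptimizerOptions.ON_1
    return conf
```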