Commit bde13a8d authored by Yuxin Wu

always use non-fused op for BatchNorm inference; support NCHW for BatchRenorm

parent f5a1a67c
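The gist of the change, as a minimal standalone sketch (illustrative only, not the library code itself): at inference time the layer now always uses the non-fused tf.nn.batch_normalization op, and for NCHW input the 1-D per-channel parameters are first reshaped to [1, C, 1, 1] so that they broadcast correctly. The helper name below is hypothetical.

    import tensorflow as tf

    def bn_inference_nchw(x, moving_mean, moving_var, beta, gamma, epsilon=1e-5):
        # Hypothetical helper mirroring the idea of this commit: non-fused batch
        # norm at inference time, with per-channel vectors reshaped so that they
        # broadcast against an [N, C, H, W] tensor.
        chan = x.get_shape().as_list()[1]        # channel axis for NCHW
        new_shape = [1, chan, 1, 1]
        mm, mv, b, g = [tf.reshape(p, new_shape)
                        for p in (moving_mean, moving_var, beta, gamma)]
        return tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)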
@@ -130,6 +130,14 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
         return tf.identity(xn, name='output')
 
 
+def reshape_for_bn(param, ndims, chan, data_format):
+    if ndims == 2:
+        shape = [1, chan]
+    else:
+        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
+    return tf.reshape(param, shape)
+
+
 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
@@ -168,47 +176,48 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    if len(shape) == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'
     if data_format == 'NCHW':
         n_out = shape[1]
     else:
         n_out = shape[-1]  # channel
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
     beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
 
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    if use_local_stat != ctx.is_training:
+    elif use_local_stat != ctx.is_training:
         # we allow the use of local_stat in testing (only print warnings)
         # because it is useful to certain applications.
         logger.warn("[BatchNorm] use_local_stat != is_training")
 
     if use_local_stat:
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
+            # fused_bn has error using NCHW? (see #190)
         xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
             x, gamma, beta, epsilon=epsilon,
             is_training=True, data_format=data_format)
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
         assert not ctx.is_training, "In training, local statistics has to be used!"
-        if data_format == 'NCHW':
-            # fused is slower in inference, but support NCHW
-            xn, _, _ = tf.nn.fused_batch_norm(
-                x, gamma, beta,
-                moving_mean, moving_var,
-                epsilon=epsilon, is_training=False, data_format=data_format)
+        # non-fused op is faster for inference
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
         else:
-            xn = tf.nn.batch_normalization(    # work only for NHWC when moving_mean is a vector
+            # avoid the reshape if possible (when channel is the last dimension)
+            xn = tf.nn.batch_normalization(
                 x, moving_mean, moving_var, beta, gamma, epsilon)
-    if len(shape) == 2:
-        axis = [2, 3] if data_format == 'NCHW' else [1, 2]
-        xn = tf.squeeze(xn, axis)
 
     # maintain EMA only on one GPU.
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
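As a quick sanity check of the broadcasting trick used in the inference branch above (an illustrative TF 1.x snippet, not part of the commit), normalizing an NCHW tensor with parameters reshaped to [1, C, 1, 1] should match normalizing its NHWC transpose with the original 1-D parameters:

    import numpy as np
    import tensorflow as tf

    N, C, H, W = 2, 3, 4, 4
    x_nchw = tf.constant(np.random.rand(N, C, H, W).astype('float32'))
    mean = tf.constant(np.random.rand(C).astype('float32'))
    var = tf.constant(np.random.rand(C).astype('float32') + 0.5)
    beta, gamma, eps = tf.zeros([C]), tf.ones([C]), 1e-5

    # NCHW path: reshape the per-channel vectors to [1, C, 1, 1] for broadcasting
    reshape = lambda p: tf.reshape(p, [1, C, 1, 1])
    out_nchw = tf.nn.batch_normalization(
        x_nchw, reshape(mean), reshape(var), reshape(beta), reshape(gamma), eps)

    # NHWC path: transpose the input and use the 1-D parameters directly
    x_nhwc = tf.transpose(x_nchw, [0, 2, 3, 1])
    out_nhwc = tf.nn.batch_normalization(x_nhwc, mean, var, beta, gamma, eps)

    with tf.Session() as sess:
        a, b = sess.run([out_nchw, tf.transpose(out_nhwc, [0, 3, 1, 2])])
        print(np.abs(a - b).max())    # expected ~0, up to float32 rounding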
@@ -219,7 +228,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 # TODO support NCHW
 @layer_register(log_shape=False)
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True):
+                use_scale=True, use_bias=True, data_format='NHWC'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
@@ -244,10 +253,16 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    n_out = shape[-1]
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'    # error using NCHW? (see #190)
+    if data_format == 'NCHW':
+        n_out = shape[1]
+    else:
+        n_out = shape[-1]  # channel
+    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
     beta, gamma, moving_mean, moving_var = get_bn_variables(
         n_out, use_scale, use_bias, tf.constant_initializer(1.0))
@@ -257,21 +272,34 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # different usage comes out in the future.
 
     if use_local_stat:
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
-                                                           epsilon=epsilon, is_training=True)
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])
+        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
+            x, gamma, beta, epsilon=epsilon, is_training=True, data_format=data_format)
         inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
         r = tf.stop_gradient(tf.clip_by_value(
             tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
         d = tf.stop_gradient(tf.clip_by_value(
             (batch_mean - moving_mean) * inv_sigma,
             -dmax, dmax))
+        r = reshape_for_bn(r, ndims, n_out, data_format)
+        d = reshape_for_bn(d, ndims, n_out, data_format)
         xn = xn * r + d
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
-        xn = tf.nn.batch_normalization(
-            x, moving_mean, moving_var, beta, gamma, epsilon)
-    if len(shape) == 2:
-        xn = tf.squeeze(xn, [1, 2])
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
+        else:
+            xn = tf.nn.batch_normalization(
+                x, moving_mean, moving_var, beta, gamma, epsilon)
 
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
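For reference, the renormalization correction computed in the training branch above can be read as a standalone function. This is a hedged sketch with a hypothetical name (the actual layer uses the reshape_for_bn helper shown earlier): the batch-normalized output xn is scaled by r and shifted by d, both derived from batch versus moving statistics, clipped to [1/rmax, rmax] and [-dmax, dmax], and excluded from the gradient.

    import tensorflow as tf

    def renorm_correction(xn, batch_mean, batch_var, moving_mean, moving_var,
                          rmax, dmax, chan, data_format='NCHW'):
        # r and d compare batch statistics against the moving statistics; both are
        # clipped and treated as constants w.r.t. the gradient, as in the paper.
        inv_sigma = tf.rsqrt(moving_var)
        r = tf.stop_gradient(tf.clip_by_value(
            tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
        d = tf.stop_gradient(tf.clip_by_value(
            (batch_mean - moving_mean) * inv_sigma, -dmax, dmax))
        # reshape so the 1-D corrections broadcast against a 4-D feature map
        shape = [1, chan, 1, 1] if data_format == 'NCHW' else [1, 1, 1, chan]
        return xn * tf.reshape(r, shape) + tf.reshape(d, shape)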