Commit bde13a8d authored by Yuxin Wu

always use non-fused op for BatchNorm inference; support NCHW for BatchRenorm

parent f5a1a67c
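The change boils down to one idea: at inference time, use the plain (non-fused) tf.nn.batch_normalization op, and reshape the per-channel vectors (moving mean, moving variance, beta, gamma) so they broadcast correctly when the channel axis is not the last one. Below is a minimal standalone sketch of that idea, assuming TF 1.x-style graph code; bn_inference and broadcast_shape are illustrative names, not tensorpack APIs.

import tensorflow as tf

def broadcast_shape(ndims, chan, data_format):
    # Same rule as reshape_for_bn in the diff: give a length-chan vector a
    # shape that broadcasts against an (N, C), (N, H, W, C) or (N, C, H, W) tensor.
    if ndims == 2:
        return [1, chan]
    return [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]

def bn_inference(x, moving_mean, moving_var, beta, gamma,
                 epsilon=1e-5, data_format='NHWC'):
    # Non-fused inference path: y = gamma * (x - mean) / sqrt(var + eps) + beta.
    ndims = x.get_shape().ndims
    chan = moving_mean.get_shape().as_list()[0]
    if ndims == 4 and data_format == 'NCHW':
        shape = broadcast_shape(ndims, chan, data_format)
        moving_mean, moving_var, beta, gamma = [
            tf.reshape(t, shape) for t in (moving_mean, moving_var, beta, gamma)]
    # For NHWC or 2D inputs the vectors already broadcast along the last axis.
    return tf.nn.batch_normalization(x, moving_mean, moving_var, beta, gamma, epsilon)

The same reshape trick is what lets BatchRenorm apply its per-channel r and d corrections (xn * r + d) under NCHW in the diff below.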
@@ -130,6 +130,14 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
     return tf.identity(xn, name='output')
 
 
+def reshape_for_bn(param, ndims, chan, data_format):
+    if ndims == 2:
+        shape = [1, chan]
+    else:
+        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
+    return tf.reshape(param, shape)
+
+
 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
@@ -168,47 +176,48 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    if len(shape) == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'
     if data_format == 'NCHW':
         n_out = shape[1]
     else:
         n_out = shape[-1]  # channel
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
     beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
 
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    if use_local_stat != ctx.is_training:
+    elif use_local_stat != ctx.is_training:
         # we allow the use of local_stat in testing (only print warnings)
         # because it is useful to certain applications.
         logger.warn("[BatchNorm] use_local_stat != is_training")
 
     if use_local_stat:
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
+        # fused_bn has error using NCHW? (see #190)
         xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
             x, gamma, beta, epsilon=epsilon,
             is_training=True, data_format=data_format)
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
         assert not ctx.is_training, "In training, local statistics has to be used!"
-        if data_format == 'NCHW':
-            # fused is slower in inference, but support NCHW
-            xn, _, _ = tf.nn.fused_batch_norm(
-                x, gamma, beta,
-                moving_mean, moving_var,
-                epsilon=epsilon, is_training=False, data_format=data_format)
+        # non-fused op is faster for inference
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
         else:
-            xn = tf.nn.batch_normalization(  # work only for NHWC when moving_mean is a vector
+            # avoid the reshape if possible (when channel is the last dimension)
+            xn = tf.nn.batch_normalization(
                 x, moving_mean, moving_var, beta, gamma, epsilon)
-    if len(shape) == 2:
-        axis = [2, 3] if data_format == 'NCHW' else [1, 2]
-        xn = tf.squeeze(xn, axis)
 
     # maintain EMA only on one GPU.
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
@@ -219,7 +228,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 # TODO support NCHW
 @layer_register(log_shape=False)
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True):
+                use_scale=True, use_bias=True, data_format='NHWC'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
@@ -244,10 +253,16 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    n_out = shape[-1]
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'    # error using NCHW? (see #190)
+    if data_format == 'NCHW':
+        n_out = shape[1]
+    else:
+        n_out = shape[-1]  # channel
+    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
     beta, gamma, moving_mean, moving_var = get_bn_variables(
         n_out, use_scale, use_bias, tf.constant_initializer(1.0))
@@ -257,21 +272,34 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # different usage comes out in the future.
     if use_local_stat:
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
-                                                           epsilon=epsilon, is_training=True)
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])
+        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
+            x, gamma, beta, epsilon=epsilon, is_training=True, data_format=data_format)
         inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
         r = tf.stop_gradient(tf.clip_by_value(
             tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
         d = tf.stop_gradient(tf.clip_by_value(
             (batch_mean - moving_mean) * inv_sigma,
             -dmax, dmax))
+        r = reshape_for_bn(r, ndims, n_out, data_format)
+        d = reshape_for_bn(d, ndims, n_out, data_format)
         xn = xn * r + d
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
-        xn = tf.nn.batch_normalization(
-            x, moving_mean, moving_var, beta, gamma, epsilon)
-    if len(shape) == 2:
-        xn = tf.squeeze(xn, [1, 2])
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
+        else:
+            xn = tf.nn.batch_normalization(
+                x, moving_mean, moving_var, beta, gamma, epsilon)
 
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
......