Commit fe049874 authored by Yuxin Wu's avatar Yuxin Wu

Prefer fused_batch_norm in bn inference.

parent a46de342
......@@ -151,11 +151,11 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
# non-fused op is faster for inference # TODO test if this is still true
if ndims == 4 and data_format == 'NCHW':
[g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
for _ in [gamma, beta, moving_mean, moving_var]]
xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
if ndims == 4:
xn, _, _ = tf.nn.fused_batch_norm(
x, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
# avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment