Commit fe049874 authored by Yuxin Wu's avatar Yuxin Wu

Prefer fused_batch_norm in bn inference.

parent a46de342
...@@ -151,11 +151,11 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -151,11 +151,11 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
mean=moving_mean, variance=moving_var, epsilon=epsilon, mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False) data_format=data_format, is_training=False)
else: else:
# non-fused op is faster for inference # TODO test if this is still true if ndims == 4:
if ndims == 4 and data_format == 'NCHW': xn, _, _ = tf.nn.fused_batch_norm(
[g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format) x, gamma, beta,
for _ in [gamma, beta, moving_mean, moving_var]] mean=moving_mean, variance=moving_var, epsilon=epsilon,
xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon) data_format=data_format, is_training=False)
else: else:
# avoid the reshape if possible (when channel is the last dimension) # avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization( xn = tf.nn.batch_normalization(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment