Commit 21c49469 authored by Yuxin Wu

fix BatchNorm+fp16 in inference mode

parent 56dd2390
@@ -183,6 +183,11 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         tf_args['virtual_batch_size'] = virtual_batch_size
     else:
         assert virtual_batch_size is None, "Feature not supported in this version of TF!"
+    use_fp16 = inputs.dtype == tf.float16
+    if use_fp16:
+        # non-fused does not support fp16; fused does not support all layouts.
+        # we made our best guess here
+        tf_args['fused'] = True
     layer = tf.layers.BatchNormalization(**tf_args)
     xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
...
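The change forces the fused batch-norm kernel whenever the input tensor is float16, since (per the comment in the diff) the non-fused implementation does not support fp16. Below is a minimal sketch, not part of the commit, of what this means for a caller; it assumes TF 1.x, where tf.layers.BatchNormalization is available, and uses a hypothetical NHWC placeholder input.

# Sketch only: shows the behaviour the commit relies on (assumes TF 1.x).
import tensorflow as tf

x = tf.placeholder(tf.float16, shape=[None, 224, 224, 3])   # hypothetical NHWC float16 input

# With fused=True the kernel accepts a float16 input while keeping gamma/beta
# and the moving statistics in float32; per the commit comment, the non-fused
# path does not support fp16 at all.
layer = tf.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5, fused=True)
y = layer(x, training=False)   # inference mode, as in the commit message
print(y.dtype)                 # float16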