Commit aa1f82f7 authored by Yuxin Wu

fix bug in f6ede612

parent eafe564b
@@ -163,7 +163,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
            don't want to update it.
        2. As long as `training=True`, `sync_statistics` and `ema_update` option will take effect.
    """
+   # parse training/ctx
    ctx = get_current_tower_context()
+   if training is None:
+       training = ctx.is_training
+   training = bool(training)
+
    # parse shapes
    data_format = get_data_format(data_format, keras_mode=False)
    shape = inputs.get_shape().as_list()
@@ -200,10 +205,6 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
    TF_version = get_tf_version_tuple()
-   # parse training/ctx
-   if training is None:
-       training = ctx.is_training
-   training = bool(training)
    freeze_bn_backward = not training and ctx.is_training
    if freeze_bn_backward:
        assert TF_version >= (1, 4), \
@@ -212,6 +213,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
        logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.
    do_sync_bn = (sync_statistics is not None) and training
    if not do_sync_bn:
...
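Context for the fix, as read from the diff: the commit moves the `# parse training/ctx` block from after `TF_version = get_tf_version_tuple()` to the top of `BatchNorm`, right after the docstring, so that `training` is resolved before any of the code in between runs. Below is a minimal standalone sketch of that resolution logic. It assumes tensorpack's real `get_current_tower_context` helper; the `resolve_training` wrapper itself is hypothetical, added only to show the moved block in isolation.

# Hypothetical standalone version of the "# parse training/ctx" block;
# after this commit, BatchNorm inlines the same logic at the top of the
# function, before anything else reads `training`.
from tensorpack.tfutils.tower import get_current_tower_context

def resolve_training(training=None):
    ctx = get_current_tower_context()
    if training is None:
        # No explicit value given: fall back to the surrounding
        # TowerContext's is_training flag, as BatchNorm does.
        training = ctx.is_training
    # Normalize to a plain Python bool.
    return bool(training)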