Commit 7c3b404a authored by Yuxin Wu's avatar Yuxin Wu

add version check about BN freeze

parent a2da829b
...@@ -9,6 +9,7 @@ from tensorflow.python.training import moving_averages ...@@ -9,6 +9,7 @@ from tensorflow.python.training import moving_averages
from ..utils import logger from ..utils import logger
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder from .common import layer_register, VariableHolder
...@@ -135,6 +136,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -135,6 +136,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
xn = tf.squeeze(xn, [1, 2]) xn = tf.squeeze(xn, [1, 2])
else: else:
if ctx.is_training: if ctx.is_training:
assert get_tf_version_number() >= 1.4, \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.index == 0: # only warn in first tower if ctx.index == 0: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.") logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we # Using moving_mean/moving_variance in training, which means we
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment