Commit 8c778cc2 authored by Yuxin Wu

save MODEL_VARIABLES in batchnorm freeze mode

parent 523dfa1c
......@@ -45,9 +45,6 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
# Only add to model var when we update them
add_model_variable(moving_mean)
add_model_variable(moving_var)
# TODO add an option, and maybe enable it for replica mode?
# with tf.control_dependencies([update_op1, update_op2]):
......@@ -160,7 +157,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
# maintain EMA only on one GPU is OK, even in replicated mode.
# because training time doesn't use EMA
if ctx.is_main_training_tower and use_local_stat:
if ctx.is_main_training_tower:
add_model_variable(moving_mean)
add_model_variable(moving_var)
if use_local_stat:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
ret = tf.identity(xn, name='output')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment