Commit c9e03b73 authored by Yuxin Wu's avatar Yuxin Wu

Add MODEL_VARIABLES in SyncBN (fix #1270)

parent 610a8b9b
......@@ -37,6 +37,10 @@ def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(1.0), trainable=False)
if get_current_tower_context().is_main_training_tower:
for v in [moving_mean, moving_var]:
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
return beta, gamma, moving_mean, moving_var
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment