Commit 07783edb authored by Yuxin Wu

Sync BatchNorm statistics with nccl or horovod

parent bffcfc1b
@@ -36,7 +36,6 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
 def update_bn_ema(xn, batch_mean, batch_var,
                   moving_mean, moving_var, decay, internal_update):
-    # TODO is there a way to use zero_debias in multi-GPU?
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
         name='mean_ema_op')
@@ -147,7 +146,6 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
                 mean=moving_mean, variance=moving_var, epsilon=epsilon,
                 data_format=data_format, is_training=False)
         else:
-            # avoid the reshape if possible (when channel is the last dimension)
             xn = tf.nn.batch_normalization(
                 inputs, moving_mean, moving_var, beta, gamma, epsilon)
...
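For context on what "sync statistics" means in this commit: each GPU computes the mean and variance of its own sub-batch, those moments are all-reduced (averaged) across workers, and every GPU then normalizes with the same global statistics. The sketch below illustrates this idea with Horovod; it is not the code added by this commit, and the helper name synced_batch_stats and the NHWC reduction axes are assumptions made for the example.

# Minimal sketch (illustration only, not this commit's implementation):
# average per-GPU batch moments across Horovod workers before normalizing.
import tensorflow as tf
import horovod.tensorflow as hvd


def synced_batch_stats(inputs, axes=(0, 1, 2)):
    # Local first and second moments of an NHWC input, one value per channel.
    local_mean = tf.reduce_mean(inputs, axis=axes)
    local_sq_mean = tf.reduce_mean(tf.square(inputs), axis=axes)

    # hvd.allreduce averages a tensor over all workers by default, so every
    # GPU ends up with the same global moments.
    global_mean = hvd.allreduce(local_mean)
    global_sq_mean = hvd.allreduce(local_sq_mean)

    # Var[x] = E[x^2] - (E[x])^2, computed from the globally averaged moments.
    global_var = global_sq_mean - tf.square(global_mean)
    return global_mean, global_var

The resulting global moments can then be passed to tf.nn.batch_normalization in place of per-GPU statistics, and used to update the moving averages as in update_bn_ema above.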