Commit 6b10019e authored by Yuxin Wu

Support internal_update in BN

parent 35527038
@@ -8,7 +8,7 @@ import copy
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
MapData, imgaug, TestDataSpeed, PrefetchDataZMQ,
imgaug, TestDataSpeed, PrefetchDataZMQ, MultiProcessMapData,
MapDataComponent, DataFromList)
# import tensorpack.utils.viz as tpviz
@@ -251,8 +251,7 @@ def get_train_dataflow(add_mask=False):
# tpviz.interactive_imshow(viz)
return ret
ds = MapData(ds, preprocess)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessMapData(ds, 3, preprocess)
return ds
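For reference, a minimal runnable sketch of the new-style pipeline (the toy `DataFromList` source and the identity `preprocess` below are placeholders, not part of this commit):

```python
from tensorpack.dataflow import DataFromList, MultiProcessMapData, TestDataSpeed

def preprocess(dp):
    # Stand-in for the real per-datapoint work (decoding, augmentation, ...).
    return dp

if __name__ == '__main__':
    ds = DataFromList([[i] for i in range(1000)], shuffle=False)
    # Replaces the old MapData(ds, preprocess) + PrefetchDataZMQ(ds, 1) pair:
    # the map function itself now runs in 3 worker processes.
    ds = MultiProcessMapData(ds, 3, preprocess)
    TestDataSpeed(ds, 1000).start()
```

With the old pair, `preprocess` ran inside a single prefetch process; `MultiProcessMapData` distributes the mapping over several workers, which is usually where the time goes.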
@@ -37,7 +37,8 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
def update_bn_ema(xn, batch_mean, batch_var,
moving_mean, moving_var, decay, internal_update):
# TODO is there a way to use zero_debias in multi-GPU?
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
@@ -46,9 +47,10 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
# TODO add an option, and maybe enable it for replica mode?
# with tf.control_dependencies([update_op1, update_op2]):
# return tf.identity(xn, name='output')
if internal_update:
with tf.control_dependencies([update_op1, update_op2]):
return tf.identity(xn, name='output')
else:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
return xn
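The practical difference shows up on the caller side. A hedged TF1-style sketch of the two modes (the tiny model below is illustrative, not from this commit):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable('w', [4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
opt = tf.train.GradientDescentOptimizer(0.1)

# internal_update=False (default): the EMA assign ops are collected in
# tf.GraphKeys.UPDATE_OPS, so the training op must depend on them explicitly.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = opt.minimize(loss)

# internal_update=True: the assign ops are already control dependencies of the
# BN layer's 'output' tensor, so a plain opt.minimize(loss) triggers them.
# Convenient, but slightly slower because the updates sit on the critical path
# of the layer output.
```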
@@ -65,7 +67,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
@layer_register()
def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True,
gamma_init=tf.constant_initializer(1.0), data_format='NHWC'):
gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
internal_update=False):
"""
Batch Normalization layer, as described in the paper:
`Batch Normalization: Accelerating Deep Network Training by
@@ -79,6 +82,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
gamma_init: initializer for gamma (the scale).
internal_update (bool): if False, add EMA update ops to
`tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
by control dependencies, which is slightly slower.
Returns:
tf.Tensor: a tensor named ``output`` with the same shape as x.
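A hedged usage sketch of the new argument (the input shape, layer names, and TowerContext setup here are illustrative assumptions, not part of this commit):

```python
import tensorflow as tf
from tensorpack.models import BatchNorm
from tensorpack.tfutils.tower import TowerContext

image = tf.placeholder(tf.float32, [None, 28, 28, 16], name='image')
with TowerContext('', is_training=True):
    # Default: EMA update ops are appended to tf.GraphKeys.UPDATE_OPS.
    out_a = BatchNorm('bn_a', image)
    # internal_update=True: the updates run as control dependencies of the
    # layer output, so no extra handling of UPDATE_OPS is needed.
    out_b = BatchNorm('bn_b', image, internal_update=True)
```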
@@ -161,7 +167,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
add_model_variable(moving_mean)
add_model_variable(moving_var)
if ctx.is_main_training_tower and use_local_stat:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay, internal_update)
else:
ret = tf.identity(xn, name='output')