Commit 6b10019e authored by Yuxin Wu

Support internal_update in BN

parent 35527038
@@ -8,7 +8,7 @@ import copy
 from tensorpack.utils.argtools import memoized, log_once
 from tensorpack.dataflow import (
-    MapData, imgaug, TestDataSpeed, PrefetchDataZMQ,
+    imgaug, TestDataSpeed, PrefetchDataZMQ, MultiProcessMapData,
     MapDataComponent, DataFromList)
 # import tensorpack.utils.viz as tpviz
@@ -251,8 +251,7 @@ def get_train_dataflow(add_mask=False):
             # tpviz.interactive_imshow(viz)
         return ret
-    ds = MapData(ds, preprocess)
-    ds = PrefetchDataZMQ(ds, 1)
+    ds = MultiProcessMapData(ds, 3, preprocess)
     return ds
...
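A minimal sketch (not part of the commit) of what the dataflow change amounts to; the toy DataFromList source and the no-op preprocess below are stand-ins for the real COCO pipeline in get_train_dataflow(). The point is that the per-datapoint preprocessing now runs in 3 worker processes inside MultiProcessMapData, instead of running in a single process whose output was only prefetched by PrefetchDataZMQ(ds, 1).

from tensorpack.dataflow import (
    DataFromList, MapData, PrefetchDataZMQ, MultiProcessMapData)

def preprocess(dp):
    # stand-in for the real per-image preprocessing
    return dp

raw = DataFromList([[i] for i in range(1000)], shuffle=False)

# before this commit: map in one process, overlap with training via a
# single prefetch process
old_ds = PrefetchDataZMQ(MapData(raw, preprocess), 1)

# after this commit: the map itself runs in 3 parallel worker processes
new_ds = MultiProcessMapData(raw, 3, preprocess)

if __name__ == '__main__':
    new_ds.reset_state()
    for dp in new_ds.get_data():
        pass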
@@ -37,7 +37,8 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     return beta, gamma, moving_mean, moving_var
-def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
+def update_bn_ema(xn, batch_mean, batch_var,
+                  moving_mean, moving_var, decay, internal_update):
     # TODO is there a way to use zero_debias in multi-GPU?
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
@@ -46,12 +47,13 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
         moving_var, batch_var, decay, zero_debias=False,
         name='var_ema_op')
-    # TODO add an option, and maybe enable it for replica mode?
-    # with tf.control_dependencies([update_op1, update_op2]):
-    #     return tf.identity(xn, name='output')
-    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
-    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
-    return xn
+    if internal_update:
+        with tf.control_dependencies([update_op1, update_op2]):
+            return tf.identity(xn, name='output')
+    else:
+        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
+        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
+        return xn
 def reshape_for_bn(param, ndims, chan, data_format):
@@ -65,7 +67,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0), data_format='NHWC'):
+              gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
+              internal_update=False):
     """
     Batch Normalization layer, as described in the paper:
     `Batch Normalization: Accelerating Deep Network Training by
@@ -79,6 +82,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
         gamma_init: initializer for gamma (the scale).
+        internal_update (bool): if False, add EMA update ops to
+            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
+            which will be slightly slower.
     Returns:
         tf.Tensor: a tensor named ``output`` with the same shape of x.
@@ -161,7 +167,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     add_model_variable(moving_mean)
     add_model_variable(moving_var)
     if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay, internal_update)
     else:
         ret = tf.identity(xn, name='output')
...
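To illustrate what the new flag means for a caller (a rough sketch, not from this commit; the placeholder input and the explicit TowerContext are assumptions needed to build the layer outside a real trainer): with the default internal_update=False the EMA update ops still go to tf.GraphKeys.UPDATE_OPS and have to be run by the training loop, while internal_update=True attaches them to the layer output via control dependencies, so just evaluating the output also updates the moving statistics.

import tensorflow as tf
from tensorpack.models import BatchNorm
from tensorpack.tfutils.tower import TowerContext

image = tf.placeholder(tf.float32, [None, 32, 32, 3], name='image')

with TowerContext('', is_training=True):
    # default mode: EMA update ops are collected in tf.GraphKeys.UPDATE_OPS
    # and must be run explicitly, e.g. by making the train op depend on them
    out_a = BatchNorm('bn_external', image)

    # new option: EMA updates are wrapped around the output with
    # tf.control_dependencies, so fetching the output updates the EMA too
    out_b = BatchNorm('bn_internal', image, internal_update=True)

# typical handling of the collected ops in the default mode:
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# with tf.control_dependencies(update_ops):
#     train_op = optimizer.minimize(loss)

Per the docstring added above, the internal_update path is slightly slower (each step waits on the EMA assign ops), but it removes the need to collect and run UPDATE_OPS separately.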