Commit 43d2bffb authored by Yuxin Wu

Use some tflayers argument names in batchnorm (#627)

parent 8e5a46a4
@@ -13,6 +13,7 @@
 from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from .common import layer_register, VariableHolder
+from .tflayer import convert_to_tflayer_args

 __all__ = ['BatchNorm', 'BatchRenorm']
@@ -66,9 +67,17 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0),
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5,
+              scale=True, center=True,
+              gamma_initializer=tf.ones_initializer(),
               data_format='channels_last',
               internal_update=False):
     """
@@ -80,10 +89,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
             Defaults to True in training and False in inference.
-        decay (float): decay rate of moving average.
+        momentum (float): momentum of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
-        gamma_init: initializer for gamma (the scale).
+        scale, center (bool): whether to use the extra affine transformation or not.
+        gamma_initializer: initializer for gamma (the scale).
         internal_update (bool): if False, add EMA update ops to
             `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
             which will be slightly slower.
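With the mapping above, callers can switch to the tf.layers-style names while code that still passes the legacy names should keep working. A usage sketch (`x` is assumed to be an input tensor; the layer name as first argument follows tensorpack's `layer_register` convention):

# new tf.layers-style keyword names
y = BatchNorm('bn', x, momentum=0.99, center=False, scale=True)

# legacy tensorpack names, remapped by convert_to_tflayer_args to the call above
y = BatchNorm('bn', x, decay=0.99, use_bias=False, use_scale=True)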
@@ -122,7 +131,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     else:
         n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -170,21 +179,29 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     add_model_variable(moving_mean)
     add_model_variable(moving_var)
     if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay, internal_update)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
     else:
         ret = tf.identity(xn, name='output')
     vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret

 @layer_register()
-def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None,
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+                scale=True, center=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
@@ -221,15 +238,15 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
         axis=1 if data_format == 'channels_first' else 3,
-        momentum=decay, epsilon=epsilon,
-        center=use_bias, scale=use_scale,
+        momentum=momentum, epsilon=epsilon,
+        center=center, scale=scale,
         renorm=True,
         renorm_clipping={
             'rmin': 1.0 / rmax,
             'rmax': rmax,
             'dmax': dmax},
         renorm_momentum=0.99,
-        gamma_initializer=gamma_init,
+        gamma_initializer=gamma_initializer,
         fused=False)
     xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
@@ -246,8 +263,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # TODO not sure whether to add moving_mean/moving_var to VH now
     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = layer.gamma
-    if use_bias:
+    if center:
         vh.beta = layer.beta
     return ret
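BatchRenorm gets the same renaming. For reference, a call with the new names might look as follows (values are illustrative only; per the clipping dict above, rmin is derived as 1.0 / rmax):

# rmax/dmax bound the Batch Renormalization correction factors r and d
y = BatchRenorm('bn', x, rmax=3.0, dmax=5.0, momentum=0.9, center=True, scale=True)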
@@ -57,7 +57,6 @@ def AvgPooling(
 @layer_register(log_shape=True)
-@convert_to_tflayer_args(args_names=[], name_mapping={})
 def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
@@ -70,6 +69,7 @@ def GlobalAvgPooling(x, data_format='channels_last'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
+    data_format = get_data_format(data_format)
     axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
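The added `get_data_format` call presumably normalizes aliases such as 'NHWC'/'NCHW' to 'channels_last'/'channels_first' before the reduction axes are picked; the pooling itself is just a mean over the spatial axes, e.g.:

import numpy as np
import tensorflow as tf

# NHWC ('channels_last') input of shape (1, 2, 3, 4): global average pooling
# reduces over the spatial axes [1, 2], leaving an N x C tensor of shape (1, 4).
x = tf.constant(np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4))
out = tf.reduce_mean(x, [1, 2], name='output')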
@@ -46,7 +46,7 @@ def describe_trainable_vars():
     summary_msg = colored(
         "\nTotal #vars={}, #params={}, size={:.02f}MB".format(
             len(data), total, size_mb), 'cyan')
-    logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
+    logger.info(colored("Trainable Variables: \n", 'cyan') + table + summary_msg)

 def get_shape_str(tensors):