Commit 43d2bffb authored by Yuxin Wu

Use some tflayers argument names in batchnorm (#627)

parent 8e5a46a4
@@ -13,6 +13,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from .common import layer_register, VariableHolder
+from .tflayer import convert_to_tflayer_args

 __all__ = ['BatchNorm', 'BatchRenorm']
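For readers unfamiliar with the helper imported above: `convert_to_tflayer_args` lives in tensorpack's `models/tflayer.py`. The following is only a minimal sketch of the remapping behavior this diff relies on, not the actual implementation; the exact handling of positional arguments is an assumption:

```python
import functools


def convert_to_tflayer_args(args_names, name_mapping):
    """Sketch: remap legacy tensorpack argument names to tf.layers-style names.

    args_names: names for positional args (those after the input tensor).
    name_mapping: dict of {legacy_kwarg: tflayers_kwarg}.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(inputs, *args, **kwargs):
            assert len(args) <= len(args_names), "too many positional arguments"
            # Turn the listed positional arguments into keyword arguments.
            kwargs.update(dict(zip(args_names, args)))
            # Rewrite legacy keyword names to their tf.layers equivalents.
            for old_name, new_name in name_mapping.items():
                if old_name in kwargs:
                    kwargs[new_name] = kwargs.pop(old_name)
            return func(inputs, **kwargs)
        return wrapper
    return decorator
```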
@@ -66,9 +67,17 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0),
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5,
+              scale=True, center=True,
+              gamma_initializer=tf.ones_initializer(),
               data_format='channels_last',
               internal_update=False):
     """
@@ -80,10 +89,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
             Defaults to True in training and False in inference.
-        decay (float): decay rate of moving average.
+        momentum (float): momentum of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
-        gamma_init: initializer for gamma (the scale).
+        scale, center (bool): whether to use the extra affine transformation or not.
+        gamma_initializer: initializer for gamma (the scale).
         internal_update (bool): if False, add EMA update ops to
             `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
             which will be slightly slower.
@@ -122,7 +131,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     else:
         n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)

     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -170,21 +179,29 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         add_model_variable(moving_mean)
         add_model_variable(moving_var)

     if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay, internal_update)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
     else:
         ret = tf.identity(xn, name='output')

     vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret

 @layer_register()
-def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None,
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+                scale=True, center=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
@@ -221,15 +238,15 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
         axis=1 if data_format == 'channels_first' else 3,
-        momentum=decay, epsilon=epsilon,
-        center=use_bias, scale=use_scale,
+        momentum=momentum, epsilon=epsilon,
+        center=center, scale=scale,
         renorm=True,
         renorm_clipping={
             'rmin': 1.0 / rmax,
             'rmax': rmax,
             'dmax': dmax},
         renorm_momentum=0.99,
-        gamma_initializer=gamma_init,
+        gamma_initializer=gamma_initializer,
         fused=False)
     xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
@@ -246,8 +263,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # TODO not sure whether to add moving_mean/moving_var to VH now
     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = layer.gamma
-    if use_bias:
+    if center:
         vh.beta = layer.beta
     return ret
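`BatchRenorm` keeps the same renaming table, and forwards its clipping bounds to `tf.layers.BatchNormalization` with `rmin` fixed at the reciprocal of `rmax`. A hypothetical call, with values picked only for illustration:

```python
# decay=0.9 is remapped to momentum=0.9 by the decorator; internally the
# layer builds renorm_clipping={'rmin': 1.0 / 3.0, 'rmax': 3.0, 'dmax': 5.0}.
y = BatchRenorm('brn', x, rmax=3.0, dmax=5.0, decay=0.9)
```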
@@ -57,7 +57,6 @@ def AvgPooling(
 @layer_register(log_shape=True)
-@convert_to_tflayer_args(args_names=[], name_mapping={})
 def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
@@ -70,6 +69,7 @@ def GlobalAvgPooling(x, data_format='channels_last'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
+    data_format = get_data_format(data_format)
     axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
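The newly added `get_data_format` call presumably lets callers pass either the TensorFlow or the Keras spelling of the layout. A minimal sketch of such a normalizer, under that assumption (the real tensorpack helper may differ):

```python
def get_data_format(data_format):
    # Accept 'NHWC'/'NCHW' and normalize to the keras-style names used here.
    mapping = {'NHWC': 'channels_last', 'NCHW': 'channels_first'}
    data_format = mapping.get(data_format, data_format)
    assert data_format in ('channels_last', 'channels_first'), data_format
    return data_format
```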
@@ -46,7 +46,7 @@ def describe_trainable_vars():
     summary_msg = colored(
         "\nTotal #vars={}, #params={}, size={:.02f}MB".format(
             len(data), total, size_mb), 'cyan')
-    logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
+    logger.info(colored("Trainable Variables: \n", 'cyan') + table + summary_msg)


 def get_shape_str(tensors):