Commit 1a1ec3db authored by Yuxin Wu's avatar Yuxin Wu

support more batchnorm options (#627)

parent 013565d6
...@@ -30,16 +30,18 @@ __all__ = ['BatchNorm', 'BatchRenorm'] ...@@ -30,16 +30,18 @@ __all__ = ['BatchNorm', 'BatchRenorm']
'decay': 'momentum', 'decay': 'momentum',
'use_local_stat': 'training' 'use_local_stat': 'training'
}) })
def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5, def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
center=True, scale=True, center=True, scale=True,
beta_initializer=tf.zeros_initializer(),
gamma_initializer=tf.ones_initializer(), gamma_initializer=tf.ones_initializer(),
virtual_batch_size=None,
data_format='channels_last', data_format='channels_last',
internal_update=False): internal_update=False):
""" """
Mostly equivalent to `tf.layers.batch_normalization`, but different in Mostly equivalent to `tf.layers.batch_normalization`, but different in
the following: the following:
1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored. 1. Accepts `data_format` when `axis` is None. For 2D input, this argument will be ignored.
2. Default value for `momentum` and `epsilon` is different. 2. Default value for `momentum` and `epsilon` is different.
3. Default value for `training` is automatically obtained from `TowerContext`. 3. Default value for `training` is automatically obtained from `TowerContext`.
4. Support the `internal_update` option. 4. Support the `internal_update` option.
...@@ -74,11 +76,13 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5, ...@@ -74,11 +76,13 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
shape = inputs.get_shape().as_list() shape = inputs.get_shape().as_list()
ndims = len(shape) ndims = len(shape)
assert ndims in [2, 4], ndims assert ndims in [2, 4], ndims
if ndims == 2:
data_format = 'NHWC' if axis is None:
axis = 1 if ndims == 2:
else: data_format = 'NHWC'
axis = 1 if data_format == 'NCHW' else 3 axis = 1
else:
axis = 1 if data_format == 'NCHW' else 3
# parse training/ctx # parse training/ctx
ctx = get_current_tower_context() ctx = get_current_tower_context()
...@@ -102,7 +106,9 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5, ...@@ -102,7 +106,9 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
axis=axis, axis=axis,
momentum=momentum, epsilon=epsilon, momentum=momentum, epsilon=epsilon,
center=center, scale=scale, center=center, scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer, gamma_initializer=gamma_initializer,
virtual_batch_size=virtual_batch_size,
fused=True fused=True
) )
xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope()) xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment