Commit 07e464d8 authored by Yuxin Wu

standardize arg names in LayerNorm/InstanceNorm

parent 2ff9a5f4
@@ -471,6 +471,8 @@ class TFDatasetInput(FeedfreeInput):
         self._spec = input_signature
         if self._dataset is not None:
             types = self._dataset.output_types
+            if len(types) == 1:
+                types = (types,)
             spec_types = tuple(k.dtype for k in input_signature)
             assert len(types) == len(spec_types), \
                 "Dataset and input signature have different length! {} != {}".format(
...
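For context, a hedged usage sketch of the class being patched (not part of this commit; the dataset, shapes, and the `TFDatasetInput` import path are assumptions):

```python
import tensorflow.compat.v1 as tf
from tensorpack import TFDatasetInput   # assumed import path

tf.disable_v2_behavior()
ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([8, 28, 28]), tf.zeros([8], dtype=tf.int64))).repeat()
inp = TFDatasetInput(ds)
# ds.output_types is (tf.float32, tf.int64) here; during setup it must
# match the input signature's dtypes element-for-element, which is the
# length check the added guard feeds into.
```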
@@ -71,7 +71,7 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
         'use_local_stat': 'training'
     })
 @disable_autograph()
-def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
+def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               center=True, scale=True,
               beta_initializer=tf.zeros_initializer(),
               gamma_initializer=tf.ones_initializer(),
@@ -376,7 +376,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         'gamma_init': 'gamma_initializer',
         'decay': 'momentum'
     })
-def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+def BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
...
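The substantive change in both hunks is the bare `*`, which makes every argument after it keyword-only. A minimal sketch of the effect, with a toy signature standing in for the real layers:

```python
# Toy signature mirroring the change; names and body are illustrative.
def bn(inputs, axis=None, *, training=None, momentum=0.9):
    return inputs

x = [1.0, 2.0]
bn(x, None, training=True)   # OK: hyperparameters passed by keyword
# bn(x, None, True)          # TypeError: 'training' is now keyword-only
```

Call sites that used to bind `training` positionally now fail fast instead of silently assigning the wrong hyperparameter.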
@@ -5,16 +5,26 @@
 from ..compat import tfv1 as tf  # this should be avoided first in model code
 from ..utils.argtools import get_data_format
+from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
+from .tflayer import convert_to_tflayer_args


 __all__ = ['LayerNorm', 'InstanceNorm']


 @layer_register()
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+    })
 def LayerNorm(
-        x, epsilon=1e-5,
-        use_bias=True, use_scale=True,
-        gamma_init=None, data_format='channels_last'):
+        x, epsilon=1e-5, *,
+        center=True, scale=True,
+        gamma_initializer=tf.ones_initializer(),
+        data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
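`convert_to_tflayer_args` rewrites deprecated keyword arguments into the new names before the function body runs, so legacy call sites keep working unchanged. A hedged sketch of the caller-visible effect (layer names, input, and import path are assumptions):

```python
import tensorflow.compat.v1 as tf
from tensorpack.models import LayerNorm   # assumed import path

tf.disable_v2_behavior()
x = tf.placeholder(tf.float32, [None, 64])
# Both calls reach the same body; the decorator maps
# use_bias -> center, use_scale -> scale, gamma_init -> gamma_initializer.
y1 = LayerNorm('ln1', x, use_scale=False, use_bias=True)  # legacy spelling
y2 = LayerNorm('ln2', x, scale=False, center=True)        # new spelling
```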
@@ -22,7 +32,7 @@ def LayerNorm(
     Args:
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        center, scale (bool): whether to use the extra affine transformation or not.
     """
     data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
@@ -40,15 +50,13 @@ def LayerNorm(
     if ndims == 2:
         new_shape = [1, chan]

-    if use_bias:
+    if center:
         beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
         beta = tf.reshape(beta, new_shape)
     else:
         beta = tf.zeros([1] * ndims, name='beta')
-    if use_scale:
-        if gamma_init is None:
-            gamma_init = tf.constant_initializer(1.0)
-        gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
+    if scale:
+        gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
         gamma = tf.reshape(gamma, new_shape)
     else:
         gamma = tf.ones([1] * ndims, name='gamma')
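For reference, the computation these flags gate, condensed into a self-contained sketch (2D input and `channels_last` assumed; this mirrors rather than copies the code above):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
x = tf.placeholder(tf.float32, [None, 64])   # [batch, chan]
beta = tf.zeros([1, 64])                     # the "center" term
gamma = tf.ones([1, 64])                     # the "scale" term
# Normalize each sample over its non-batch axes, then apply the affine.
mean, var = tf.nn.moments(x, [1], keep_dims=True)
y = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
# i.e. y = gamma * (x - mean) / sqrt(var + 1e-5) + beta
```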
@@ -56,15 +64,22 @@ def LayerNorm(
     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret


 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'gamma_init': 'gamma_initializer',
+    })
+def InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
+                 gamma_initializer=tf.ones_initializer(),
+                 data_format='channels_last', use_affine=None):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
@@ -73,12 +88,17 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
     Args:
         x (tf.Tensor): a 4D tensor.
         epsilon (float): avoid divide-by-zero
-        use_affine (bool): whether to apply learnable affine transformation
+        center, scale (bool): whether to use the extra affine transformation or not.
+        use_affine: deprecated. Don't use.
     """
     data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

+    if use_affine is not None:
+        log_deprecated("InstanceNorm(use_affine=)", "Use center= or scale= instead!", "2020-06-01")
+        center = scale = use_affine
+
     if data_format == 'NHWC':
         axis = [1, 2]
         ch = shape[3]
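The shim keeps `use_affine=` working until the stated removal date while logging a deprecation warning. A hedged sketch of the two equivalent spellings (layer names, input, and import path are assumptions):

```python
import tensorflow.compat.v1 as tf
from tensorpack.models import InstanceNorm   # assumed import path

tf.disable_v2_behavior()
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y_old = InstanceNorm('in1', x, use_affine=True)           # warns, then sets
                                                          # center = scale = True
y_new = InstanceNorm('in2', x, center=True, scale=True)   # preferred spelling
```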
@@ -91,19 +111,21 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
     mean, var = tf.nn.moments(x, axis, keep_dims=True)

-    if not use_affine:
-        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
-
-    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
-    beta = tf.reshape(beta, new_shape)
-    if gamma_init is None:
-        gamma_init = tf.constant_initializer(1.0)
-    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
-    gamma = tf.reshape(gamma, new_shape)
+    if center:
+        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
+        beta = tf.reshape(beta, new_shape)
+    else:
+        beta = tf.zeros([1, 1, 1, 1], name='beta', dtype=x.dtype)
+    if scale:
+        gamma = tf.get_variable('gamma', [ch], initializer=gamma_initializer)
+        gamma = tf.reshape(gamma, new_shape)
+    else:
+        gamma = tf.ones([1, 1, 1, 1], name='gamma', dtype=x.dtype)

     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_affine:
+    if scale:
         vh.gamma = gamma
+    if center:
         vh.beta = beta
     return ret
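One behavioral detail worth calling out: the old code returned the plain normalization `(x - mean) / sqrt(var + eps)` early when `use_affine` was off, whereas the new code always routes through `tf.nn.batch_normalization`, substituting constant zeros and ones. The outputs agree up to floating-point rounding, as a quick self-contained check suggests (toy shapes assumed):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
x = tf.constant(np.random.rand(2, 5, 5, 3).astype('float32'))
mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)

old = (x - mean) / tf.sqrt(var + 1e-5)            # old early-return path
new = tf.nn.batch_normalization(                  # new unconditional path
    x, mean, var,
    tf.zeros([1, 1, 1, 1]),                       # beta when center=False
    tf.ones([1, 1, 1, 1]),                        # gamma when scale=False
    1e-5)
with tf.Session() as sess:
    a, b = sess.run([old, new])
    print(np.abs(a - b).max())                    # ~0, up to rounding
```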