Commit 54eebb3b authored by Yuxin Wu

gamma_init in InstanceNorm & LayerNorm

parent bc3f8353
@@ -10,7 +10,10 @@ __all__ = ['LayerNorm', 'InstanceNorm']
 @layer_register()
-def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'):
+def LayerNorm(
+        x, epsilon=1e-5,
+        use_bias=True, use_scale=True,
+        gamma_init=None, data_format='NHWC'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

@@ -41,7 +44,9 @@ def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'
     else:
         beta = tf.zeros([1] * ndims, name='beta')
     if use_scale:
-        gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
+        if gamma_init is None:
+            gamma_init = tf.constant_initializer(1.0)
+        gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
         gamma = tf.reshape(gamma, new_shape)
     else:
         gamma = tf.ones([1] * ndims, name='gamma')
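
With gamma_init exposed, callers can override the scale initializer instead of always starting gamma at 1.0. A minimal usage sketch, assuming a TF1-style graph and the usual tensorpack convention that a registered layer takes its variable-scope name as the first argument (import path as in the tensorpack examples); the placeholder shape and scope names are illustrative, not part of this commit:

import tensorflow as tf
from tensorpack import LayerNorm

x = tf.placeholder(tf.float32, [None, 32, 32, 64])

# Omitting gamma_init keeps the old behaviour: gamma starts at 1.0.
y_default = LayerNorm('ln_default', x)

# Passing any TF initializer overrides the starting scale, e.g. a smaller constant.
y_scaled = LayerNorm('ln_scaled', x, gamma_init=tf.constant_initializer(0.1))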

@@ -50,7 +55,7 @@ def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'
 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, data_format='NHWC', use_affine=True):
+def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization

@@ -81,6 +86,8 @@ def InstanceNorm(x, epsilon=1e-5, data_format='NHWC', use_affine=True):
         beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
         beta = tf.reshape(beta, new_shape)
-        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
+        if gamma_init is None:
+            gamma_init = tf.constant_initializer(1.0)
+        gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
         gamma = tf.reshape(gamma, new_shape)
     return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
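
InstanceNorm gets the same hook: when use_affine is set, gamma is now created from gamma_init, falling back to the old constant 1.0 when the argument is left as None. A hedged sketch of overriding it with the N(1, 0.02) scale initialization common in GAN-style code; the tensor shape and scope name are again only illustrative:

import tensorflow as tf
from tensorpack import InstanceNorm

feat = tf.placeholder(tf.float32, [None, 64, 64, 128])

# use_affine defaults to True, so beta/gamma are created; gamma now uses gamma_init.
styled = InstanceNorm('inorm', feat,
                      gamma_init=tf.random_normal_initializer(1.0, 0.02))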

@@ -44,18 +44,18 @@ class LinearWrap(object):
             # this is a registered tensorpack layer
             # parse arguments by tensorpack model convention
             if layer.use_scope:
-                def f(name, *args, **kwargs):
+                def layer_func(name, *args, **kwargs):
                     ret = layer(name, self._t, *args, **kwargs)
                     return LinearWrap(ret)
             else:
-                def f(*args, **kwargs):
+                def layer_func(*args, **kwargs):
                     if len(args) and isinstance(args[0], six.string_types):
                         name, args = args[0], args[1:]
                         ret = layer(name, self._t, *args, **kwargs)
                     else:
                         ret = layer(self._t, *args, **kwargs)
                     return LinearWrap(ret)
-            return f
+            return layer_func
         else:
             assert layer_name == 'tf', \
                 "Calling LinearWrap.{}:" \
......
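
The LinearWrap change is a pure rename (f becomes layer_func) for readability; the closure is still what every chained attribute access ends up calling. A short sketch of the chaining style it implements, assuming a TF1 graph, that the layer names below resolve through tensorpack's layer registry, and top-level imports as in the tensorpack examples; the architecture is made up for illustration:

import tensorflow as tf
from tensorpack import LinearWrap

image = tf.placeholder(tf.float32, [None, 28, 28, 3])

# Each chained call is dispatched through layer_func: the scope name is passed
# first, the wrapped tensor self._t is inserted as the layer input, and the
# result comes back wrapped in a new LinearWrap so the chain continues.
logits = (LinearWrap(image)
          .Conv2D('conv0', 32, 3, nl=tf.nn.relu)
          .MaxPooling('pool0', 2)
          .LayerNorm('ln0', gamma_init=tf.constant_initializer(1.0))
          .FullyConnected('fc0', 10, nl=tf.identity)())  # trailing () unwraps the tensor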