Commit 54eebb3b authored by Yuxin Wu

gamma_init in InstanceNorm & LayerNorm

parent bc3f8353
@@ -10,7 +10,10 @@ __all__ = ['LayerNorm', 'InstanceNorm']

 @layer_register()
-def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'):
+def LayerNorm(
+        x, epsilon=1e-5,
+        use_bias=True, use_scale=True,
+        gamma_init=None, data_format='NHWC'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
@@ -41,7 +44,9 @@ def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'
     else:
         beta = tf.zeros([1] * ndims, name='beta')
     if use_scale:
-        gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
+        if gamma_init is None:
+            gamma_init = tf.constant_initializer(1.0)
+        gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
         gamma = tf.reshape(gamma, new_shape)
     else:
         gamma = tf.ones([1] * ndims, name='gamma')
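
With this change, callers can override the default scale initializer of 1.0. A minimal usage sketch (the tensor names and shapes are illustrative, not from the commit), assuming LayerNorm is importable from tensorpack's top-level namespace:

    import tensorflow as tf
    from tensorpack import LayerNorm

    # NHWC input with 64 channels; placeholder shape is illustrative.
    x = tf.placeholder(tf.float32, [None, 32, 32, 64], name='input')
    # Start gamma at zero, e.g. so a normalized residual branch begins near identity.
    y = LayerNorm('ln', x, gamma_init=tf.constant_initializer(0.))

Leaving gamma_init=None (the default) keeps the old behavior of initializing gamma to 1.0.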
@@ -50,7 +55,7 @@ def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'

 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, data_format='NHWC', use_affine=True):
+def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
@@ -81,6 +86,8 @@ def InstanceNorm(x, epsilon=1e-5, data_format='NHWC', use_affine=True):
         beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
         beta = tf.reshape(beta, new_shape)
-        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
+        if gamma_init is None:
+            gamma_init = tf.constant_initializer(1.0)
+        gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
         gamma = tf.reshape(gamma, new_shape)
         return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
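
Note that InstanceNorm's keyword order also changed (use_affine now precedes data_format, with gamma_init between them), so any caller passing these arguments positionally needs updating. A sketch of the new keyword, with illustrative names and shapes:

    import tensorflow as tf
    from tensorpack import InstanceNorm

    x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='image')
    # Safest to pass everything by keyword after this reordering.
    y = InstanceNorm('inorm', x, use_affine=True,
                     gamma_init=tf.constant_initializer(1.0), data_format='NHWC')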
@@ -44,18 +44,18 @@ class LinearWrap(object):
             # this is a registered tensorpack layer
             # parse arguments by tensorpack model convention
             if layer.use_scope:
-                def f(name, *args, **kwargs):
+                def layer_func(name, *args, **kwargs):
                     ret = layer(name, self._t, *args, **kwargs)
                     return LinearWrap(ret)
             else:
-                def f(*args, **kwargs):
+                def layer_func(*args, **kwargs):
                     if len(args) and isinstance(args[0], six.string_types):
                         name, args = args[0], args[1:]
                         ret = layer(name, self._t, *args, **kwargs)
                     else:
                         ret = layer(self._t, *args, **kwargs)
                     return LinearWrap(ret)
-            return f
+            return layer_func
         else:
             assert layer_name == 'tf', \
                 "Calling LinearWrap.{}:" \
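
The rename from f to layer_func in LinearWrap.__getattr__ is cosmetic: it only gives the returned closure a more descriptive name (e.g. in tracebacks); the chaining behavior is unchanged. A sketch of the call path it implements, assuming a trailing () on a LinearWrap returns the wrapped tensor and using illustrative layer arguments:

    import tensorflow as tf
    from tensorpack import LinearWrap

    image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    # Each attribute access returns the layer_func closure above, which
    # applies the registered layer to the wrapped tensor and re-wraps the result.
    logits = (LinearWrap(image)
              .Conv2D('conv0', 32, 3)
              .FullyConnected('fc0', 10)())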