Commit f5a1a67c authored by Yuxin Wu

add layer normalization

parent 6d41928f
...
@@ -140,7 +140,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     Reducing Internal Covariance Shift <http://arxiv.org/abs/1502.03167>`_.

     Args:
-        x (tf.Tensor): a NHWC or NC tensor.
+        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
             Defaults to True in training and False in inference.
         decay (float): decay rate of moving average.
@@ -202,7 +202,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
                 moving_mean, moving_var,
                 epsilon=epsilon, is_training=False, data_format=data_format)
         else:
-            xn = tf.nn.batch_normalization(
+            xn = tf.nn.batch_normalization(  # work only for NHWC when moving_mean is a vector
                 x, moving_mean, moving_var, beta, gamma, epsilon)
     if len(shape) == 2:
...
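The inline comment added to the tf.nn.batch_normalization call is about broadcasting: a 1-D per-channel moving_mean/moving_var only lines up with the last axis of the input, which is the channel axis only for NHWC. A minimal standalone sketch of that behaviour (not the library's code; shapes and the reshape are illustrative):

    import tensorflow as tf

    x_nhwc = tf.zeros([8, 32, 32, 64])   # channels-last input
    x_nchw = tf.zeros([8, 64, 32, 32])   # channels-first input
    mean = tf.zeros([64])                # per-channel statistic, shape [C]
    var = tf.ones([64])

    # OK: a [C] vector broadcasts against the last (channel) axis of NHWC input.
    y = tf.nn.batch_normalization(x_nhwc, mean, var, None, None, 1e-5)

    # For NCHW input the statistics would first need an explicit reshape to
    # [1, C, 1, 1]; the code above instead uses fused_batch_norm with
    # data_format for that case.
    mean4d = tf.reshape(mean, [1, 64, 1, 1])
    var4d = tf.reshape(var, [1, 64, 1, 1])
    y2 = tf.nn.batch_normalization(x_nchw, mean4d, var4d, None, None, 1e-5)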
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: layer_norm.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from .common import layer_register


@layer_register(log_shape=False)
def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'):
    """
    Layer Normalization layer, as described in the paper:
    `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

    Args:
        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
        epsilon (float): epsilon to avoid divide-by-zero.
        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
    """
    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]

    # normalize each sample over all of its non-batch axes
    mean, var = tf.nn.moments(x, list(range(1, ndims)), keep_dims=True)

    # beta/gamma are per-channel; compute the shape they need to broadcast against x
    if data_format == 'NCHW':
        chan = shape[1]
        new_shape = [1, chan, 1, 1]
    else:
        chan = shape[-1]
        new_shape = [1, 1, 1, chan]
    if ndims == 2:
        new_shape = [1, chan]

    if use_bias:
        beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        beta = tf.zeros([1] * ndims, name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
    else:
        gamma = tf.ones([1] * ndims, name='gamma')

    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
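A minimal usage sketch of the new layer, assuming it is exported from tensorpack.models (the import path is an assumption) and following the usual tensorpack convention of passing the variable-scope name as the first argument:

    import tensorflow as tf
    from tensorpack.models import LayerNorm

    images = tf.placeholder(tf.float32, [None, 32, 32, 64])  # NHWC feature map
    normed = LayerNorm('ln', images)                          # normalizes each sample over H, W, C
    fc_in = tf.placeholder(tf.float32, [None, 256])           # 2D (NC) input also works
    fc_out = LayerNorm('ln_fc', fc_in, use_scale=False)       # skip the learned scale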
...
@@ -75,7 +75,8 @@ class SessionUpdate(object):
                     # TODO only allow reshape when shape different by empty axis
                     assert np.prod(varshape) == np.prod(val.shape), \
                         "{}: {}!={}".format(name, varshape, val.shape)
-                    logger.warn("Variable {} is reshaped during assigning".format(name))
+                    logger.warn("Variable {} is reshaped {}->{} during assigning".format(
+                        name, val.shape, varshape))
                     val = val.reshape(varshape)
         # fix some common type incompatibility problem, but is certainly not enough
...
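For context, a standalone numpy sketch of the reshape-on-assign behaviour the updated warning describes (hypothetical variable name and shapes; this is not SessionUpdate itself):

    import numpy as np

    varshape = (64,)             # shape of the variable in the graph
    val = np.zeros((1, 1, 64))   # shape of the value loaded from a checkpoint

    # only the total number of elements must match; the value is then reshaped
    assert np.prod(varshape) == np.prod(val.shape), \
        "{}: {}!={}".format('some/var', varshape, val.shape)
    print("Variable {} is reshaped {}->{} during assigning".format(
        'some/var', val.shape, varshape))
    val = val.reshape(varshape)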