Shashank Suhas / seminar-breakout · Commits

Commit f5a1a67c authored Apr 06, 2017 by Yuxin Wu

    add layer normalization

parent 6d41928f

Showing 3 changed files with 51 additions and 3 deletions:

    tensorpack/models/batch_norm.py    +2   -2
    tensorpack/models/layer_norm.py    +47  -0
    tensorpack/tfutils/varmanip.py     +2   -1
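For reference, the layer normalization of the cited paper (https://arxiv.org/abs/1607.06450) normalizes each sample over its own features and then applies a learned affine transform. With H the number of features in one sample:

    \mu = \frac{1}{H}\sum_{i=1}^{H} a_i, \qquad
    \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (a_i - \mu)^2, \qquad
    y_i = \gamma_i \, \frac{a_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i

In the new file below, the statistics are taken per sample over all non-batch axes, while gamma and beta are per-channel parameters (reshaped for broadcasting).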
tensorpack/models/batch_norm.py

...
@@ -140,7 +140,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     Reducing Internal Covariance Shift <http://arxiv.org/abs/1502.03167>`_.

     Args:
-        x (tf.Tensor): a NHWC or NC tensor.
+        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
             Defaults to True in training and False in inference.
         decay (float): decay rate of moving average.
...
@@ -202,7 +202,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
                 moving_mean, moving_var,
                 epsilon=epsilon, is_training=False, data_format=data_format)
         else:
-            xn = tf.nn.batch_normalization(
+            xn = tf.nn.batch_normalization(  # work only for NHWC when moving_mean is a vector
                 x, moving_mean, moving_var, beta, gamma, epsilon)
     if len(shape) == 2:
...
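The new inline comment records a real constraint: when moving_mean and moving_var are rank-1 [C] vectors, tf.nn.batch_normalization normalizes via plain broadcasting, which aligns a [C] vector with the last axis only, i.e. with NHWC layouts. A minimal sketch of that behaviour (the shapes are illustrative, not from the repo):

    import numpy as np

    x_nhwc = np.zeros((8, 32, 32, 64))  # channels last
    x_nchw = np.zeros((8, 64, 32, 32))  # channels first
    mean = np.zeros(64)                 # per-channel statistic, shape [C]

    # Broadcasting aligns trailing axes, so a [C] vector matches NHWC:
    print((x_nhwc - mean).shape)        # (8, 32, 32, 64) -- per-channel, as intended

    # Against NCHW the same expression raises, because the trailing
    # axis is 32 while the vector has length 64:
    # (x_nchw - mean)  ->  ValueError

This is why the NCHW path above goes through tf.nn.fused_batch_norm, which accepts data_format explicitly.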
tensorpack/models/layer_norm.py  (new file, 0 → 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: layer_norm.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from .common import layer_register


@layer_register(log_shape=False)
def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'):
    """
    Layer Normalization layer, as described in the paper:
    `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

    Args:
        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
        epsilon (float): epsilon to avoid divide-by-zero.
        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
    """
    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]

    mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)

    if data_format == 'NCHW':
        chan = shape[1]
        new_shape = [1, chan, 1, 1]
    else:
        chan = shape[-1]
        new_shape = [1, 1, 1, chan]
    if ndims == 2:
        new_shape = [1, chan]

    if use_bias:
        beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        beta = tf.zeros([1] * ndims, name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
    else:
        gamma = tf.ones([1] * ndims, name='gamma')

    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
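A minimal usage sketch, assuming tensorpack re-exports the layer from tensorpack.models and that registered layers take the variable-scope name as their first argument, as other layers in this repo do:

    import tensorflow as tf
    from tensorpack.models import LayerNorm  # assumed re-export

    x = tf.placeholder(tf.float32, [None, 32, 32, 64])   # NHWC input
    y = LayerNorm('ln', x)                               # normalize each sample over H, W, C
    y_plain = LayerNorm('ln2', x, use_scale=False, use_bias=False)  # no affine transform

Note that disabling use_scale/use_bias substitutes constant ones/zeros, so the returned tensor is the plain per-sample normalization.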
tensorpack/tfutils/varmanip.py

...
@@ -75,7 +75,8 @@ class SessionUpdate(object):
                 # TODO only allow reshape when shape different by empty axis
                 assert np.prod(varshape) == np.prod(val.shape), \
                     "{}: {}!={}".format(name, varshape, val.shape)
-                logger.warn("Variable {} is reshaped during assigning".format(name))
+                logger.warn("Variable {} is reshaped {}->{} during assigning".format(
+                    name, val.shape, varshape))
                 val = val.reshape(varshape)
         # fix some common type incompatibility problem, but is certainly not enough
...
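The guard above only permits reshapes that preserve the element count; the updated warning now also reports both shapes. A small standalone sketch of the check, with hypothetical shapes:

    import numpy as np

    varshape = (1, 1, 1, 64)      # shape of the variable in the graph
    val = np.random.rand(64)      # value loaded from a checkpoint, shape (64,)

    # Same number of elements -> the reshape is lossless and therefore allowed.
    assert np.prod(varshape) == np.prod(val.shape), \
        "{}!={}".format(varshape, val.shape)
    val = val.reshape(varshape)   # (64,) -> (1, 1, 1, 64)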