Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
12d27154
Commit
12d27154
authored
Jun 18, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use unbiased variance in training
parent
132dcccd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
9 deletions
+10
-9
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+10
-9
No files found.
tensorpack/models/batch_norm.py
View file @
12d27154
...
...
@@ -30,8 +30,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
* Epsilon for variance is set to 1e-5, as is `torch/nn <https://github.com/torch/nn/blob/master/BatchNormalization.lua>`_.
:param input: a NHWC tensor or a NC vector
:param use_local_stat: bool. whether to use mean/var of this batch or the running average.
Usually set to True in training and False in testing
:param use_local_stat: bool. whether to use mean/var of this batch or the moving average. Set to True in training and False in testing
:param decay: decay rate. default to 0.9.
:param epsilon: default to 1e-5.
"""
...
...
@@ -40,9 +39,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
assert
len
(
shape
)
in
[
2
,
4
]
n_out
=
shape
[
-
1
]
# channel
assert
n_out
is
not
None
beta
=
tf
.
get_variable
(
'beta'
,
[
n_out
])
gamma
=
tf
.
get_variable
(
'gamma'
,
[
n_out
],
gamma
=
tf
.
get_variable
(
'gamma'
,
[
n_out
],
initializer
=
tf
.
constant_initializer
(
1.0
))
if
len
(
shape
)
==
2
:
...
...
@@ -51,7 +50,8 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
emaname
=
'EMA'
if
not
batch_mean
.
name
.
startswith
(
'towerp'
):
in_train_tower
=
not
batch_mean
.
name
.
startswith
(
'towerp'
)
if
in_train_tower
:
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
...
...
@@ -65,6 +65,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
G
=
tf
.
get_default_graph
()
# find training statistics in training tower
try
:
mean_name
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
ema_mean
.
name
)
var_name
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
ema_var
.
name
)
...
...
@@ -81,11 +82,11 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
if
use_local_stat
:
with
tf
.
control_dependencies
([
ema_apply_op
]):
batch
=
tf
.
cast
(
tf
.
shape
(
x
)[
0
],
tf
.
float32
)
mul
=
tf
.
select
(
tf
.
equal
(
batch
,
1.0
),
1.0
,
batch
/
(
batch
-
1
))
batch_var
=
batch_var
*
mul
# use unbiased variance estimator in training
return
tf
.
nn
.
batch_normalization
(
x
,
batch_mean
,
batch_var
,
beta
,
gamma
,
epsilon
,
'bn'
)
else
:
batch
=
tf
.
cast
(
tf
.
shape
(
x
)[
0
],
tf
.
float32
)
# XXX TODO batch==1?
mean
,
var
=
ema_mean
,
ema_var
*
batch
/
(
batch
-
1
)
# unbiased variance estimator
return
tf
.
nn
.
batch_normalization
(
x
,
mean
,
var
,
beta
,
gamma
,
epsilon
,
'bn'
)
x
,
ema_mean
,
ema_var
,
beta
,
gamma
,
epsilon
,
'bn'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment