Shashank Suhas / seminar-breakout / Commits

Commit ebf1d570
authored Aug 03, 2017 by Yuxin Wu

fix BatchRenorm (fix #360)

parent 8b487b90

Showing 2 changed files with 33 additions and 45 deletions:

tensorpack/models/batch_norm.py    +32  -45
tensorpack/tfutils/collection.py    +1   -0
tensorpack/models/batch_norm.py  (view file @ ebf1d570)
@@ -6,8 +6,10 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
+from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm
 from ..tfutils.tower import get_current_tower_context
+from ..tfutils.collection import backup_collection, restore_collection
 from ..utils import logger
 from .common import layer_register, VariableHolder
@@ -31,7 +33,7 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     moving_mean = tf.get_variable('mean/EMA', [n_out],
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.constant_initializer(), trainable=False)
+                                 initializer=tf.constant_initializer(1.0), trainable=False)
     return beta, gamma, moving_mean, moving_var
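The only functional change in this hunk initializes the variance EMA to 1.0 instead of the default 0. A quick numeric check (not part of the commit) of why a zero-initialized variance is troublesome for anything that divides by its square root, as the old BatchRenorm did via tf.rsqrt(moving_var):

import numpy as np

# 1 / sqrt(moving_var) at step 0, before any EMA update has happened:
print(1.0 / np.sqrt(np.float64(0.0)))  # inf -> the renorm correction factors blow up
print(1.0 / np.sqrt(np.float64(1.0)))  # 1.0 -> a harmless identity correction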
@@ -179,8 +181,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     * ``beta``: the bias term.
     * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
-    * ``mean/EMA``: the moving average of mean.
-    * ``variance/EMA``: the moving average of variance.
+    * ``moving_mean, renorm_mean, renorm_mean_weight``: See TF documentation.
+    * ``moving_variance, renorm_stddev, renorm_stddev_weight``: See TF documentation.
     """
     shape = x.get_shape().as_list()
@@ -188,59 +190,44 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     assert ndims in [2, 4]
     if ndims == 2:
         data_format = 'NHWC'    # error using NCHW? (see #190)
+        x = tf.reshape(x, [-1, 1, 1, shape[1]])
     if data_format == 'NCHW':
         n_out = shape[1]
     else:
         n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, tf.constant_initializer(1.0))

     ctx = get_current_tower_context()
-    use_local_stat = ctx.is_training
-    # for BatchRenorm, use_local_stat should always be is_training, unless a
-    # different usage comes out in the future.
-
-    if use_local_stat:
-        if ndims == 2:
-            x = tf.reshape(x, [-1, 1, 1, n_out])
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
-            x, gamma, beta, epsilon=epsilon, is_training=True, data_format=data_format)
-        inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
-        r = tf.stop_gradient(tf.clip_by_value(
-            tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
-        d = tf.stop_gradient(tf.clip_by_value(
-            (batch_mean - moving_mean) * inv_sigma, -dmax, dmax))
-        r = reshape_for_bn(r, ndims, n_out, data_format)
-        d = reshape_for_bn(d, ndims, n_out, data_format)
-        xn = xn * r + d
-        if ndims == 2:
-            xn = tf.squeeze(xn, [1, 2])
-    else:
-        if ndims == 4 and data_format == 'NCHW':
-            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
-                              for _ in [gamma, beta, moving_mean, moving_var]]
-            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
-        else:
-            xn = tf.nn.batch_normalization(x, moving_mean, moving_var, beta, gamma, epsilon)
-
-    # training also needs EMA, so we should maintain it as long as there are
-    # corresponding EMA variables.
+    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
+    layer = TF_BatchNorm(
+        axis=1 if data_format == 'NCHW' else 3,
+        momentum=decay, epsilon=epsilon,
+        center=use_bias, scale=use_scale,
+        renorm=True,
+        renorm_clipping={'rmin': 1.0 / rmax,
+                         'rmax': rmax,
+                         'dmax': dmax},
+        renorm_momentum=0.99,
+        fused=False)
+    xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
+
     if ctx.has_own_variables:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        # only apply update in this case
+        for v in layer.non_trainable_variables:
+            add_model_variable(v)
     else:
-        ret = tf.identity(xn, name='output')
+        # don't need update if we are sharing variables from an old tower
+        restore_collection(coll_bk)

-    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    if ndims == 2:
+        xn = tf.squeeze(xn, [1, 2])
+    ret = tf.identity(xn, name='output')
+
+    # TODO not sure whether to add moving_mean/moving_var to VH now
+    vh = ret.variables = VariableHolder()
     if use_scale:
-        vh.gamma = gamma
+        vh.gamma = layer.gamma
     if use_bias:
-        vh.beta = beta
+        vh.beta = layer.beta
     return ret
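For comparison with the removed branch: the renorm_clipping dict hands to the TF layer the same correction the old code computed by hand. A minimal NumPy sketch (not from the commit; the helper name renorm_correction is made up for illustration) of the r/d factors that rmax and dmax bound:

import numpy as np

def renorm_correction(batch_mean, batch_var, moving_mean, moving_var, rmax, dmax):
    # Mirrors the deleted code: r rescales by the ratio of batch to moving stddev,
    # d shifts by the normalized offset between batch mean and moving mean.
    inv_sigma = 1.0 / np.sqrt(moving_var)
    r = np.clip(np.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax)     # bounded by 'rmin'/'rmax'
    d = np.clip((batch_mean - moving_mean) * inv_sigma, -dmax, dmax)  # bounded by 'dmax'
    return r, d

# The batch-normalized output is then corrected as xn * r + d before gamma/beta,
# which is what renorm=True with renorm_clipping={'rmin': 1.0 / rmax, 'rmax': rmax,
# 'dmax': dmax} asks the layer to do internally.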
tensorpack/tfutils/collection.py  (view file @ ebf1d570)
@@ -22,6 +22,7 @@ def backup_collection(keys):
         dict: the backup
     """
     ret = {}
+    assert isinstance(keys, (list, tuple))
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret
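The added assertion simply rejects a single key passed where a list or tuple was expected. For context, the backup/restore pair is how the new BatchRenorm discards the UPDATE_OPS that the TF layer registers when a tower only reuses variables. A minimal sketch of that pattern, assuming a TF 1.x graph (the layer construction in the middle is assumed, not part of this file):

import tensorflow as tf
from tensorpack.tfutils.collection import backup_collection, restore_collection

# Snapshot the update-op collection before building a layer that appends EMA updates.
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])

# ... build the batch-norm / batch-renorm layer here ...

sharing_old_tower = True  # illustrative flag: this tower reuses another tower's variables
if sharing_old_tower:
    # Drop the layer's update ops by restoring the earlier snapshot.
    restore_collection(coll_bk)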