Shashank Suhas / seminar-breakout / Commits

Commit bde13a8d
Authored Apr 07, 2017 by Yuxin Wu
Parent: f5a1a67c

always use non-fused op for BatchNorm inference; support NCHW for BatchRenorm

Showing 1 changed file with 57 additions and 29 deletions:
tensorpack/models/batch_norm.py (+57, -29)
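
In short, the commit drops the fused kernel on BatchNorm's inference path and instead reshapes the per-channel statistics so that the plain tf.nn.batch_normalization op can serve both NHWC and NCHW; BatchRenorm gains a data_format argument and reuses the same reshape helper. As a point of reference, a minimal standalone sketch of the non-fused inference computation (TF 1.x graph mode; the shapes and variable names are illustrative, not taken from tensorpack):

    import tensorflow as tf

    # Hypothetical NHWC activation with 64 channels.
    x = tf.placeholder(tf.float32, [None, 32, 32, 64])
    beta = tf.get_variable('beta', [64], initializer=tf.zeros_initializer())
    gamma = tf.get_variable('gamma', [64], initializer=tf.ones_initializer())
    moving_mean = tf.get_variable('mean', [64], initializer=tf.zeros_initializer(), trainable=False)
    moving_var = tf.get_variable('variance', [64], initializer=tf.ones_initializer(), trainable=False)

    # The non-fused op computes (x - mean) / sqrt(var + eps) * gamma + beta,
    # broadcasting the [C] vectors against the last axis of x.
    xn = tf.nn.batch_normalization(x, moving_mean, moving_var, beta, gamma, 1e-5)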
@@ -130,6 +130,14 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
         return tf.identity(xn, name='output')
 
 
+def reshape_for_bn(param, ndims, chan, data_format):
+    if ndims == 2:
+        shape = [1, chan]
+    else:
+        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
+    return tf.reshape(param, shape)
+
+
 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
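
The reshape_for_bn helper added above exists because tf.nn.batch_normalization broadcasts its [C] mean/variance/scale/offset vectors against the last axis of the input; with NCHW data the channel axis is 1, so each vector has to be reshaped first. A quick sketch with a hypothetical channel count of 64:

    import tensorflow as tf

    gamma = tf.ones([64])                      # per-channel parameter, shape [C]
    g_nhwc = tf.reshape(gamma, [1, 1, 1, 64])  # broadcasts against NHWC input
    g_nchw = tf.reshape(gamma, [1, 64, 1, 1])  # broadcasts against NCHW input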
@@ -168,47 +176,48 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    if len(shape) == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'
     if data_format == 'NCHW':
         n_out = shape[1]
     else:
         n_out = shape[-1]  # channel
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
     beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
 
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    if use_local_stat != ctx.is_training:
+    elif use_local_stat != ctx.is_training:
         # we allow the use of local_stat in testing (only print warnings)
         # because it is useful to certain applications.
         logger.warn("[BatchNorm] use_local_stat != is_training")
 
     if use_local_stat:
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
+            # fused_bn has error using NCHW? (see #190)
         xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
             x, gamma, beta, epsilon=epsilon,
             is_training=True, data_format=data_format)
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
         assert not ctx.is_training, "In training, local statistics has to be used!"
-        if data_format == 'NCHW':
-            # fused is slower in inference, but support NCHW
-            xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
-                                              moving_mean, moving_var,
-                                              epsilon=epsilon, is_training=False,
-                                              data_format=data_format)
+        # non-fused op is faster for inference
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
         else:
-            xn = tf.nn.batch_normalization(
-                # work only for NHWC when moving_mean is a vector
+            # avoid the reshape if possible (when channel is the last dimension)
+            xn = tf.nn.batch_normalization(
                 x, moving_mean, moving_var, beta, gamma, epsilon)
 
-    if len(shape) == 2:
-        axis = [2, 3] if data_format == 'NCHW' else [1, 2]
-        xn = tf.squeeze(xn, axis)
-
     # maintain EMA only on one GPU.
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
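
For orientation, a hedged usage sketch of the layer after this change, assuming tensorpack's registered-layer calling convention of that era (variable-scope name first, then the input tensor); the scope names and tensor are placeholders:

    # inside a tensorpack model function
    l = BatchNorm('bn', l)                           # defaults; training vs. inference decided by the tower context
    l = BatchNorm('bn_nchw', l, data_format='NCHW')  # NCHW input at inference now uses the non-fused path above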
@@ -219,7 +228,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 # TODO support NCHW
 @layer_register(log_shape=False)
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True):
+                use_scale=True, use_bias=True, data_format='NHWC'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
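
Since the signature now takes data_format, an NCHW feature map can be passed directly. A hedged usage sketch, again assuming the name-first calling convention; rmax and dmax are normally scheduled tensors, so the constants here are placeholders only:

    l = BatchRenorm('brn', l, rmax=3.0, dmax=5.0, data_format='NCHW')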
@@ -244,10 +253,16 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    n_out = shape[-1]
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
+    ndims = len(shape)
+    assert ndims in [2, 4]
+    if ndims == 2:
+        data_format = 'NHWC'    # error using NCHW? (see #190)
+    if data_format == 'NCHW':
+        n_out = shape[1]
+    else:
+        n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
 
     beta, gamma, moving_mean, moving_var = get_bn_variables(
         n_out, use_scale, use_bias, tf.constant_initializer(1.0))
@@ -257,21 +272,34 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # different usage comes out in the future.
 
     if use_local_stat:
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
-                                                           epsilon=epsilon, is_training=True)
+        if ndims == 2:
+            x = tf.reshape(x, [-1, 1, 1, n_out])
+        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
+            x, gamma, beta, epsilon=epsilon,
+            is_training=True, data_format=data_format)
 
         inv_sigma = tf.rsqrt(moving_var, 'inv_sigma')
         r = tf.stop_gradient(tf.clip_by_value(
             tf.sqrt(batch_var) * inv_sigma, 1.0 / rmax, rmax))
         d = tf.stop_gradient(tf.clip_by_value(
             (batch_mean - moving_mean) * inv_sigma,
             -dmax, dmax))
+        r = reshape_for_bn(r, ndims, n_out, data_format)
+        d = reshape_for_bn(d, ndims, n_out, data_format)
         xn = xn * r + d
+        if ndims == 2:
+            xn = tf.squeeze(xn, [1, 2])
     else:
-        xn = tf.nn.batch_normalization(
-            x, moving_mean, moving_var, beta, gamma, epsilon)
-
-    if len(shape) == 2:
-        xn = tf.squeeze(xn, [1, 2])
+        if ndims == 4 and data_format == 'NCHW':
+            [g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
+                              for _ in [gamma, beta, moving_mean, moving_var]]
+            xn = tf.nn.batch_normalization(x, mm, mv, b, g, epsilon)
+        else:
+            xn = tf.nn.batch_normalization(
+                x, moving_mean, moving_var, beta, gamma, epsilon)
 
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
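
As a sanity check on the training-branch arithmetic above: when r and d are not clipped, multiplying the batch-normalized output by r = sigma_B / sigma and adding d = (mu_B - mu) / sigma reproduces normalization by the moving statistics (ignoring the scale/offset terms). A small NumPy verification, illustrative only:

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.randn(8)

    mu_b, sigma_b = x.mean(), x.std()   # batch statistics
    mu, sigma = 0.3, 1.7                # pretend moving statistics

    r = sigma_b / sigma                 # what tf.clip_by_value bounds to [1/rmax, rmax]
    d = (mu_b - mu) / sigma             # bounded to [-dmax, dmax]

    lhs = (x - mu_b) / sigma_b * r + d  # xn * r + d, as in the layer
    rhs = (x - mu) / sigma              # normalization by the moving statistics
    assert np.allclose(lhs, rhs)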