Shashank Suhas / seminar-breakout · Commits

Commit f1d15364
authored Dec 05, 2016 by Yuxin Wu

add BNV2 which uses fused_batch_norm

parent 78ccd295

Showing 2 changed files with 75 additions and 3 deletions (+75 / -3):

- examples/GAN/README.md (+1 / -0)
- tensorpack/models/batch_norm.py (+74 / -3)

examples/GAN/README.md

@@ -34,4 +34,5 @@ It requires the datasets released by the original authors.
 Reproduce a mnist experiement in InfoGAN.
 By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
 the network learns to map the 10 variables to 10 digits in a completely unsupervised way.

tensorpack/models/batch_norm.py

@@ -5,6 +5,7 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
+from tensorflow.python.training import moving_averages
 from copy import copy
 import re
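
For reference (this note is not part of the commit): with zero_debias=False, the newly imported moving_averages.assign_moving_average keeps a plain exponential moving average of its second argument, i.e. it applies the update sketched below each time it runs. The variable names are illustrative only.

    # EMA update performed by assign_moving_average(var, value, decay, zero_debias=False):
    #   var <- decay * var + (1 - decay) * value
    decay = 0.9
    moving_mean, batch_mean = 0.0, 2.0
    moving_mean = decay * moving_mean + (1 - decay) * batch_mean   # -> 0.2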

@@ -12,13 +13,12 @@ from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ._common import layer_register

-__all__ = ['BatchNorm']
+__all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']

-# TF batch_norm only works for 4D tensor right now: #804
 # decay: being too close to 1 leads to slow start-up. torch use 0.9.
 # eps: torch: 1e-5. Lasagne: 1e-4

 @layer_register(log_shape=False)
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
+def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     """
     Batch normalization layer as described in:

@@ -107,3 +107,74 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     else:
         return tf.nn.batch_normalization(
             x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
+
+
+@layer_register(log_shape=False)
+def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
+    """
+    Batch normalization layer as described in:
+    `Batch Normalization: Accelerating Deep Network Training by
+    Reducing Internal Covariance Shift <http://arxiv.org/abs/1502.03167>`_.
+    :param input: a NHWC or NC tensor
+    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average.
+        Default to True in training and False in inference.
+    :param decay: decay rate. default to 0.9.
+    :param epsilon: default to 1e-5.
+    """
+    shape = x.get_shape().as_list()
+    assert len(shape) in [2, 4]
+    n_out = shape[-1]  # channel
+    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
+    if len(shape) == 2:
+        x = tf.reshape(x, [-1, 1, 1, n_out])
+
+    beta = tf.get_variable('beta', [n_out],
+                           initializer=tf.zeros_initializer)
+    gamma = tf.get_variable('gamma', [n_out],
+                            initializer=tf.constant_initializer(1.0))
+    # x * gamma + beta
+
+    ctx = get_current_tower_context()
+    if use_local_stat is None:
+        use_local_stat = ctx.is_training
+    if use_local_stat != ctx.is_training:
+        logger.warn("[BatchNorm] use_local_stat != is_training")
+
+    moving_mean = tf.get_variable('mean/EMA', [n_out],
+                                  initializer=tf.zeros_initializer, trainable=False)
+    moving_var = tf.get_variable('variance/EMA', [n_out],
+                                 initializer=tf.zeros_initializer, trainable=False)
+
+    if use_local_stat:
+        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
+            x, gamma, beta, epsilon=epsilon, is_training=ctx.is_training)
+        if ctx.is_training:
+            # maintain EMA if training
+            update_op1 = moving_averages.assign_moving_average(
+                moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
+            update_op2 = moving_averages.assign_moving_average(
+                moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
+            if ctx.is_main_training_tower:
+                add_model_variable(moving_mean)
+                add_model_variable(moving_var)
+    else:
+        assert not ctx.is_training, "In training, local statistics has to be used!"
+        # TODO do I need to add_model_variable.
+        # assume some fixed-param tasks, such as load model and fine tune one layer
+
+        # fused is slower in inference
+        #xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
+        #    moving_mean, moving_var,
+        #    epsilon=epsilon, is_training=False, name='output')
+        xn = tf.nn.batch_normalization(
+            x, moving_mean, moving_var, beta, gamma, epsilon)
+
+    if ctx.is_training:
+        with tf.control_dependencies([update_op1, update_op2]):
+            return tf.identity(xn, name='output')
+    else:
+        return tf.identity(xn, name='output')
+
+
+BatchNorm = BatchNormV2
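
To make the control flow above easier to follow, here is a minimal standalone sketch (not part of the commit) of the training-time path that BatchNormV2 uses: tf.nn.fused_batch_norm returns the normalized tensor together with the batch statistics in one fused kernel, and the EMA updates are attached to the output via control dependencies. It is written against the TF 1.x-era API used in this file; the function and variable names are illustrative only.

    import tensorflow as tf
    from tensorflow.python.training import moving_averages

    def fused_bn_train(x, n_out, decay=0.9, epsilon=1e-5):
        # per-channel scale/shift, as in the diff above
        beta = tf.get_variable('beta', [n_out],
                               initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma', [n_out],
                                initializer=tf.constant_initializer(1.0))
        # EMA statistics that inference will read instead of batch statistics
        moving_mean = tf.get_variable('mean/EMA', [n_out],
                                      initializer=tf.constant_initializer(0.0),
                                      trainable=False)
        moving_var = tf.get_variable('variance/EMA', [n_out],
                                     initializer=tf.constant_initializer(1.0),
                                     trainable=False)
        # one fused kernel returns the normalized tensor and the batch mean/variance
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            x, gamma, beta, epsilon=epsilon, is_training=True)
        # update the EMAs whenever the output is evaluated
        update_mean = moving_averages.assign_moving_average(
            moving_mean, batch_mean, decay, zero_debias=False)
        update_var = moving_averages.assign_moving_average(
            moving_var, batch_var, decay, zero_debias=False)
        with tf.control_dependencies([update_mean, update_var]):
            return tf.identity(xn, name='output')

Because the module ends with BatchNorm = BatchNormV2, existing call sites keep working unchanged; through layer_register they are presumably still invoked in the usual tensorpack style (e.g. BatchNorm('bn', x)), now hitting the fused implementation by default while BatchNormV1 remains available as a fallback.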