Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
81bb9ac2
Commit
81bb9ac2
authored
Jun 18, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use a better bn variable name
parent
12d27154
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
5 deletions
+10
-5
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+6
-3
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+4
-2
No files found.
tensorpack/models/batch_norm.py
View file @
81bb9ac2
...
@@ -48,17 +48,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
...
@@ -48,17 +48,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
keep_dims
=
False
)
else
:
else
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
# just to make a clear name.
batch_mean
=
tf
.
identity
(
batch_mean
,
'mean'
)
batch_var
=
tf
.
identity
(
batch_var
,
'variance'
)
emaname
=
'EMA'
emaname
=
'EMA'
in_
tr
ain_tower
=
not
batch_mean
.
name
.
startswith
(
'towerp'
)
in_
m
ain_tower
=
not
batch_mean
.
name
.
startswith
(
'towerp'
)
if
in_
tr
ain_tower
:
if
in_
m
ain_tower
:
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
else
:
else
:
# use training-statistics in prediction
# use training-statistics in prediction
assert
not
use_local_stat
assert
not
use_local_stat
# have to do this again to get actual name. see issue:
#
XXX
have to do this again to get actual name. see issue:
# https://github.com/tensorflow/tensorflow/issues/2740
# https://github.com/tensorflow/tensorflow/issues/2740
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
...
...
tensorpack/tfutils/sessinit.py
View file @
81bb9ac2
...
@@ -132,7 +132,8 @@ class ParamRestore(SessionInit):
...
@@ -132,7 +132,8 @@ class ParamRestore(SessionInit):
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
sess
.
run
(
tf
.
initialize_all_variables
())
sess
.
run
(
tf
.
initialize_all_variables
())
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
# allow restore non-trainable variables
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
VARIABLES
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
for
name
,
value
in
six
.
iteritems
(
self
.
prms
):
for
name
,
value
in
six
.
iteritems
(
self
.
prms
):
if
not
name
.
endswith
(
':0'
):
if
not
name
.
endswith
(
':0'
):
...
@@ -145,7 +146,8 @@ class ParamRestore(SessionInit):
...
@@ -145,7 +146,8 @@ class ParamRestore(SessionInit):
logger
.
info
(
"Restoring param {}"
.
format
(
name
))
logger
.
info
(
"Restoring param {}"
.
format
(
name
))
varshape
=
tuple
(
var
.
get_shape
()
.
as_list
())
varshape
=
tuple
(
var
.
get_shape
()
.
as_list
())
if
varshape
!=
value
.
shape
:
if
varshape
!=
value
.
shape
:
assert
np
.
prod
(
varshape
)
==
np
.
prod
(
value
.
shape
)
assert
np
.
prod
(
varshape
)
==
np
.
prod
(
value
.
shape
),
\
"{}: {}!={}"
.
format
(
name
,
varshape
,
value
.
shape
)
logger
.
warn
(
"Param {} is reshaped during loading!"
.
format
(
name
))
logger
.
warn
(
"Param {} is reshaped during loading!"
.
format
(
name
))
value
=
value
.
reshape
(
varshape
)
value
=
value
.
reshape
(
varshape
)
sess
.
run
(
var
.
assign
(
value
))
sess
.
run
(
var
.
assign
(
value
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment