Shashank Suhas / seminar-breakout · Commit b9498a1a
Authored Nov 06, 2016 by Yuxin Wu
Parent: 78c7488e

    improve logging on model loading

Showing 5 changed files with 35 additions and 9 deletions:

    tensorpack/models/pool.py          +1   -1
    tensorpack/tfutils/modelutils.py   +1   -0
    tensorpack/tfutils/sessinit.py     +6   -4
    tensorpack/tfutils/summary.py      +3   -2
    tensorpack/tfutils/varmanip.py     +24  -2
tensorpack/models/pool.py

@@ -3,7 +3,7 @@
 # File: pool.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import numpy
+import numpy as np
 from ._common import layer_register, shape2d, shape4d
 from ..tfutils import symbolic_functions as symbf
...
tensorpack/tfutils/modelutils.py

@@ -39,3 +39,4 @@ def get_shape_str(tensors):
     shape_str = str(tensors.get_shape().as_list())
     return shape_str
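
For context (not part of this commit), a minimal usage sketch of get_shape_str, assuming a TF1-era graph and that the single-tensor path shown in the hunk is taken; the input tensor is hypothetical:

    import tensorflow as tf
    from tensorpack.tfutils.modelutils import get_shape_str

    # Hypothetical input tensor; get_shape_str stringifies its static shape.
    x = tf.placeholder(tf.float32, [None, 224, 224, 3], name='image')
    print(get_shape_str(x))   # -> "[None, 224, 224, 3]"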
tensorpack/tfutils/sessinit.py

@@ -10,9 +10,9 @@ import numpy as np
 import tensorflow as tf
 import six
-from ..utils import logger, EXTRA_SAVE_VARS_KEY
+from ..utils import logger
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname
+from .varmanip import SessionUpdate, get_savename_from_varname, is_training_specific_name

 __all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'ParamRestore', 'ChainInit',
...
@@ -127,7 +127,8 @@ class SaverRestore(SessionInit):
                 var_dict[name].append(v)
                 chkpt_vars_used.add(name)
             else:
-                logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
+                if not is_training_specific_name(v.op.name):
+                    logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
             unused = vars_available - chkpt_vars_used
             for name in unused:
...
@@ -156,7 +157,8 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            logger.warn("Variable {} in the graph not found in the dict!".format(k))
+            if not is_training_specific_name(k):
+                logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
...
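
The practical effect of the two changes above: when restoring, variables that exist in the graph but not in the checkpoint (or dict) are only warned about if they are not training-specific. A hedged usage sketch, with a hypothetical checkpoint path:

    import tensorflow as tf
    from tensorpack.tfutils.sessinit import SaverRestore

    sess = tf.Session()
    # '/path/to/model-10000' is a hypothetical checkpoint path.
    SaverRestore('/path/to/model-10000').init(sess)
    # Before this commit: a warning for every graph variable absent from the
    # checkpoint, including optimizer slots like 'conv1/W/Adam'.
    # After: such names are filtered by is_training_specific_name() (defined
    # in varmanip.py below) and stay silent.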
tensorpack/tfutils/summary.py

@@ -106,12 +106,13 @@ def summary_moving_average():
        :returns: a op to maintain these average.
     """
     with tf.name_scope('EMA_summary'):
         # TODO will produce EMA_summary/tower0/xxx. not elegant
         global_step_var = get_global_step_var()
-        averager = tf.train.ExponentialMovingAverage(
-            0.99, num_updates=global_step_var, name='EMA')
+        with tf.name_scope(None):
+            averager = tf.train.ExponentialMovingAverage(
+                0.99, num_updates=global_step_var, name='EMA')
         vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
         avg_maintain_op = averager.apply(vars_to_summary)
         for idx, c in enumerate(vars_to_summary):
             name = re.sub('tower[p0-9]+/', '', c.op.name)
             tf.scalar_summary(name, averager.average(c))
...
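
The scoping trick used above: in TF1, tf.name_scope(None) temporarily resets to the root name scope, so ops created inside it escape the enclosing 'EMA_summary/' prefix. A minimal sketch of the behavior (not from the repository):

    import tensorflow as tf

    with tf.name_scope('EMA_summary'):
        inside = tf.no_op(name='inside')        # named 'EMA_summary/inside'
        with tf.name_scope(None):
            outside = tf.no_op(name='outside')  # named 'outside' (root scope)

    print(inside.name, outside.name)            # EMA_summary/inside outside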
tensorpack/tfutils/varmanip.py

@@ -10,9 +10,10 @@ import re
 import numpy as np
 from ..utils import logger
 from ..utils.naming import *
+from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname']
+           'get_savename_from_varname', 'is_training_specific_name']

 def get_savename_from_varname(varname, varname_prefix=None,
...
@@ -24,7 +25,7 @@ def get_savename_from_varname(
     :returns: the name used to save the variable
     """
     name = varname
-    if 'towerp' in name:
+    if 'towerp/' in name:
         logger.error("No variable should be under 'towerp' name scope".format(v.name))
         # don't overwrite anything in the current prediction graph
         return None
...
@@ -95,3 +96,24 @@ def dump_chkpt_vars(model_path):
     for n in var_names:
         result[n] = reader.get_tensor(n)
     return result
+
+def is_training_specific_name(name):
+    """
+    This is only used to improve logging.
+    :returns: guess whether this tensor is something only used in training.
+    """
+    # TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
+    name = get_op_tensor_name(name)[0]
+    if name.endswith('/Adam') or name.endswith('/Adam_1'):
+        return True
+    if name.endswith('/Momentum'):
+        return True
+    if name.endswith('/Adadelta') or name.endswith('/Adadelta_1'):
+        return True
+    if name.endswith('/RMSProp') or name.endswith('/RMSProp_1'):
+        return True
+    if name.endswith('/Adagrad'):
+        return True
+    if 'EMA_summary/' in name:
+        return True
+    return False
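
A quick sketch of how the new helper classifies names (the import is valid, since the commit adds the function to __all__ above; the variable names below are hypothetical examples):

    from tensorpack.tfutils.varmanip import is_training_specific_name

    # Optimizer slots and EMA shadows count as training-specific;
    # plain weights do not.
    for n in ['conv1/W/Adam', 'fc0/b/Momentum', 'EMA_summary/cost/EMA', 'conv1/W']:
        print(n, is_training_specific_name(n))
    # -> True, True, True, False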