Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f603636c
Commit
f603636c
authored
Jul 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
group warning messages in SessionInit
parent
019ff1a5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
4 deletions
+29
-4
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+29
-4
No files found.
tensorpack/tfutils/sessinit.py
View file @
f603636c
...
...
@@ -78,6 +78,21 @@ class CheckpointReaderAdapter(object):
return
name
[:
-
2
]
class
MismatchLogger
(
object
):
def
__init__
(
self
,
exists
,
nonexists
):
self
.
_exists
=
exists
self
.
_nonexists
=
nonexists
self
.
_names
=
[]
def
add
(
self
,
name
):
self
.
_names
.
append
(
name
)
def
log
(
self
):
if
len
(
self
.
_names
):
logger
.
warn
(
"The following variables are in the {}, but not found in the {}: {}"
.
format
(
self
.
_exists
,
self
.
_nonexists
,
', '
.
join
(
self
.
_names
)))
class
SaverRestore
(
SessionInit
):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
...
...
@@ -114,6 +129,8 @@ class SaverRestore(SessionInit):
reader
,
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
graph_vars
=
tf
.
global_variables
()
chkpt_vars_used
=
set
()
mismatch
=
MismatchLogger
(
'graph'
,
'checkpoint'
)
for
v
in
graph_vars
:
name
=
get_savename_from_varname
(
v
.
name
,
varname_prefix
=
self
.
prefix
)
if
name
in
self
.
ignore
and
reader
.
has_tensor
(
name
):
...
...
@@ -125,12 +142,15 @@ class SaverRestore(SessionInit):
else
:
vname
=
v
.
op
.
name
if
not
is_training_name
(
vname
):
logger
.
warn
(
"Variable {} in the graph not found in checkpoint!"
.
format
(
vname
))
mismatch
.
add
(
vname
)
mismatch
.
log
()
mismatch
=
MismatchLogger
(
'checkpoint'
,
'graph'
)
if
len
(
chkpt_vars_used
)
<
len
(
chkpt_vars
):
unused
=
chkpt_vars
-
chkpt_vars_used
for
name
in
sorted
(
unused
):
if
not
is_training_name
(
name
):
logger
.
warn
(
"Variable {} in checkpoint not found in the graph!"
.
format
(
name
))
mismatch
.
add
(
name
)
mismatch
.
log
()
def
_get_restore_dict
(
self
):
var_dict
=
{}
...
...
@@ -185,11 +205,16 @@ class DictRestore(SessionInit):
logger
.
info
(
"Params to restore: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
mismatch
=
MismatchLogger
(
'graph'
,
'dict'
)
for
k
in
sorted
(
variable_names
-
param_names
):
if
not
is_training_name
(
k
):
logger
.
warn
(
"Variable {} in the graph not found in the dict!"
.
format
(
k
))
mismatch
.
add
(
k
)
mismatch
.
log
()
mismatch
=
MismatchLogger
(
'dict'
,
'graph'
)
for
k
in
sorted
(
param_names
-
variable_names
):
logger
.
warn
(
"Variable {} in the dict not found in the graph!"
.
format
(
k
))
mismatch
.
add
(
k
)
mismatch
.
log
()
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
logger
.
info
(
"Restoring from dict ..."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment