Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4be41492
Commit
4be41492
authored
Feb 18, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simplify code of two saverrestore
parent
cc89b105
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
23 deletions
+15
-23
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+15
-23
No files found.
tensorpack/tfutils/sessinit.py
View file @
4be41492
...
...
@@ -118,17 +118,14 @@ class SaverRestore(SessionInit):
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
return
reader
,
set
(
ckpt_vars
)
def
_
get_restore_dict
(
self
):
def
_
match_vars
(
self
,
func
):
reader
,
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
graph_vars
=
tf
.
global_variables
()
var_dict
=
{}
chkpt_vars_used
=
set
()
for
v
in
graph_vars
:
name
=
get_savename_from_varname
(
v
.
name
,
varname_prefix
=
self
.
prefix
)
if
reader
.
has_tensor
(
name
):
ckpt_name
=
reader
.
get_real_name
(
name
)
assert
ckpt_name
not
in
var_dict
,
"Restore conflict: {} and {}"
.
format
(
v
.
name
,
var_dict
[
ckpt_name
]
.
name
)
var_dict
[
ckpt_name
]
=
v
func
(
reader
,
name
,
v
)
chkpt_vars_used
.
add
(
name
)
else
:
vname
=
v
.
op
.
name
...
...
@@ -139,6 +136,15 @@ class SaverRestore(SessionInit):
for
name
in
sorted
(
unused
):
if
not
is_training_name
(
name
):
logger
.
warn
(
"Variable {} in checkpoint not found in the graph!"
.
format
(
name
))
def
_get_restore_dict
(
self
):
var_dict
=
{}
def
f
(
reader
,
name
,
v
):
name
=
reader
.
get_real_name
(
name
)
assert
name
not
in
var_dict
,
"Restore conflict: {} and {}"
.
format
(
v
.
name
,
var_dict
[
name
]
.
name
)
var_dict
[
name
]
=
v
self
.
_match_vars
(
f
)
return
var_dict
...
...
@@ -153,26 +159,12 @@ class SaverRestoreRelaxed(SaverRestore):
def
_run_init
(
self
,
sess
):
logger
.
info
(
"Restoring checkpoint from {} ..."
.
format
(
self
.
path
))
reader
,
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
graph_vars
=
tf
.
global_variables
()
chkpt_vars_used
=
set
()
def
f
(
reader
,
name
,
v
):
val
=
reader
.
get_tensor
(
name
)
SessionUpdate
.
load_value_to_var
(
v
,
val
)
with
sess
.
as_default
():
for
v
in
graph_vars
:
name
=
get_savename_from_varname
(
v
.
name
,
varname_prefix
=
self
.
prefix
)
if
name
in
chkpt_vars
:
val
=
reader
.
get_tensor
(
name
)
SessionUpdate
.
load_value_to_var
(
v
,
val
)
chkpt_vars_used
.
add
(
name
)
else
:
vname
=
v
.
op
.
name
if
not
is_training_name
(
vname
):
logger
.
warn
(
"Variable {} in the graph not found in checkpoint!"
.
format
(
vname
))
if
len
(
chkpt_vars_used
)
<
len
(
chkpt_vars
):
unused
=
chkpt_vars
-
chkpt_vars_used
for
name
in
sorted
(
unused
):
if
not
is_training_name
(
name
):
logger
.
warn
(
"Variable {} in checkpoint not found in the graph!"
.
format
(
name
))
self
.
_match_vars
(
f
)
class
ParamRestore
(
SessionInit
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment