Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
20d1af11
Commit
20d1af11
authored
Feb 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simplify ModeSaver. just save without rename
parent
3e97f126
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
31 deletions
+5
-31
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+1
-17
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-4
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+2
-10
No files found.
tensorpack/callbacks/saver.py
View file @
20d1af11
...
...
@@ -8,7 +8,6 @@ import shutil
from
.base
import
Triggerable
from
..utils
import
logger
from
..tfutils.varmanip
import
get_savename_from_varname
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
...
@@ -43,27 +42,12 @@ class ModelSaver(Triggerable):
vars
.
extend
(
tf
.
get_collection
(
key
))
self
.
path
=
os
.
path
.
join
(
self
.
checkpoint_dir
,
'model'
)
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
ModelSaver
.
_get_var_dict
(
vars
)
,
var_list
=
vars
,
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
,
write_version
=
tf
.
train
.
SaverDef
.
V2
)
self
.
meta_graph_written
=
False
@
staticmethod
def
_get_var_dict
(
vars
):
var_dict
=
{}
for
v
in
vars
:
name
=
get_savename_from_varname
(
v
.
name
)
if
name
not
in
var_dict
:
if
name
!=
v
.
name
:
logger
.
info
(
"[ModelSaver] {} renamed to {} when saving model."
.
format
(
v
.
name
,
name
))
var_dict
[
name
]
=
v
else
:
logger
.
info
(
"[ModelSaver] Variable {} won't be saved
\
due to an alternative in a different tower"
.
format
(
v
.
name
,
var_dict
[
name
]
.
name
))
return
var_dict
def
_trigger
(
self
):
try
:
if
not
self
.
meta_graph_written
:
...
...
tensorpack/tfutils/sessinit.py
View file @
20d1af11
...
...
@@ -143,7 +143,7 @@ class ParamRestore(SessionInit):
def
_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
# TODO
variable_names
=
set
([
get_savename_from_varname
(
k
.
name
)
for
k
in
variables
])
variable_names
=
set
([
k
.
name
for
k
in
variables
])
param_names
=
set
(
six
.
iterkeys
(
self
.
prms
))
intersect
=
variable_names
&
param_names
...
...
@@ -156,9 +156,7 @@ class ParamRestore(SessionInit):
for
k
in
sorted
(
param_names
-
variable_names
):
logger
.
warn
(
"Variable {} in the dict not found in the graph!"
.
format
(
k
))
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
get_savename_from_varname
(
v
.
name
)
in
intersect
])
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
logger
.
info
(
"Restoring from dict ..."
)
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
prms
)
if
name
in
intersect
})
...
...
tensorpack/tfutils/varmanip.py
View file @
20d1af11
...
...
@@ -7,7 +7,6 @@ import six
import
os
import
tensorflow
as
tf
from
collections
import
defaultdict
import
re
import
numpy
as
np
from
..utils
import
logger
from
..utils.naming
import
PREDICT_TOWER
...
...
@@ -34,8 +33,6 @@ def get_savename_from_varname(
logger
.
error
(
"No variable under '{}' name scope should be saved!"
.
format
(
PREDICT_TOWER
))
# don't overwrite anything in the current prediction graph
return
None
if
'tower'
in
name
:
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
name
)
if
varname_prefix
is
not
None
\
and
name
.
startswith
(
varname_prefix
):
name
=
name
[
len
(
varname_prefix
)
+
1
:]
...
...
@@ -56,8 +53,7 @@ class SessionUpdate(object):
self
.
sess
=
sess
self
.
name_map
=
defaultdict
(
list
)
for
v
in
vars_to_update
:
savename
=
get_savename_from_varname
(
v
.
name
)
self
.
name_map
[
savename
]
.
append
(
v
)
self
.
name_map
[
v
.
name
]
.
append
(
v
)
@
staticmethod
def
load_value_to_var
(
var
,
val
,
strict
=
False
):
...
...
@@ -133,11 +129,7 @@ def dump_session_params(path):
assert
len
(
set
(
var
))
==
len
(
var
),
"TRAINABLE and MODEL variables have duplication!"
result
=
{}
for
v
in
var
:
name
=
get_savename_from_varname
(
v
.
name
)
if
name
in
result
:
logger
.
info
(
"Variable {} would be stored instead of another with
\
the same name"
.
format
(
v
.
name
))
result
[
name
]
=
v
.
eval
()
result
[
v
.
name
]
=
v
.
eval
()
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
path
,
result
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment