seminar-breakout (Shashank Suhas)

Commit 73baa9ae, authored Jul 31, 2020 by Yuxin Wu

get_all_checkpoints

Parent: 425f9d27
Showing 2 changed files, with 47 additions and 13 deletions:

    docs/conf.py                        +2   -0
    tensorpack/tfutils/varmanip.py      +45  -13
docs/conf.py

@@ -407,6 +407,8 @@ _DEPRECATED_NAMES = set([
     "get_model_loader",
     # renamed items that should not appear in docs
+    'load_chkpt_vars',
+    'save_chkpt_vars',
     'DumpTensor',
     'DumpParamAsImage',
     'get_nr_gpu',
...
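For context, a deprecation list like _DEPRECATED_NAMES is typically consumed by a Sphinx autodoc-skip-member handler so the listed names are omitted from the generated API docs. A minimal sketch, assuming a conventional setup (the handler below is illustrative, not copied from this repository):

    # Hypothetical conf.py wiring: hide names in _DEPRECATED_NAMES from autodoc.
    _DEPRECATED_NAMES = set([
        'load_chkpt_vars',   # renamed to load_checkpoint_vars in this commit
        'save_chkpt_vars',   # renamed to save_checkpoint_vars in this commit
    ])

    def autodoc_skip_member(app, what, name, obj, skip, options):
        if name in _DEPRECATED_NAMES:
            return True          # skip deprecated/renamed members
        return skip              # otherwise keep Sphinx's default decision

    def setup(app):
        app.connect('autodoc-skip-member', autodoc_skip_member)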
tensorpack/tfutils/varmanip.py

 # -*- coding: utf-8 -*-
 # File: varmanip.py
+import glob
+import operator
 import numpy as np
 import os
 import pprint
...
@@ -12,7 +14,9 @@ from ..utils import logger
 from .common import get_op_tensor_name

-__all__ = ['SessionUpdate', 'dump_session_params', 'load_chkpt_vars',
-           'save_chkpt_vars', 'get_checkpoint_path']
+__all__ = ['SessionUpdate', 'dump_session_params',
+           'load_chkpt_vars', 'save_chkpt_vars',
+           'load_checkpoint_vars', 'save_checkpoint_vars',
+           'get_checkpoint_path']


 def get_savename_from_varname(
...
@@ -146,19 +150,19 @@ def dump_session_params(path):
         path(str): the file name to save the parameters. Must ends with npz.
     """
     # save variables that are GLOBAL, and either TRAINABLE or MODEL
-    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
+    var = tfv1.get_collection(tfv1.GraphKeys.TRAINABLE_VARIABLES)
+    var.extend(tfv1.get_collection(tfv1.GraphKeys.MODEL_VARIABLES))
     # TODO dedup
     assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
-    gvars = {k.name for k in tf.global_variables()}
+    gvars = {k.name for k in tfv1.global_variables()}
     var = [v for v in var if v.name in gvars]
     result = {}
     for v in var:
         result[v.name] = v.eval()
-    save_chkpt_vars(result, path)
+    save_checkpoint_vars(result, path)


-def save_chkpt_vars(dic, path):
+def save_checkpoint_vars(dic, path):
     """
     Save variables in dic to path.
...
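The hunk above swaps bare tf calls for tfv1, TensorFlow's v1 compatibility namespace (imported elsewhere in this file; tf.compat.v1 is the equivalent in stock TF2). A minimal usage sketch for the function being edited, with illustrative variable names and paths:

    import tensorflow.compat.v1 as tfv1
    from tensorpack.tfutils.varmanip import dump_session_params

    tfv1.disable_eager_execution()               # graph mode, required under TF2
    w = tfv1.get_variable('fc/W', shape=[3, 4])  # registered in TRAINABLE_VARIABLES
    with tfv1.Session() as sess:                 # the with-block installs a default
        sess.run(tfv1.global_variables_initializer())  # session for v.eval()
        dump_session_params('/tmp/params.npz')   # filename must end with .npz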
@@ -174,13 +178,13 @@ def save_chkpt_vars(dic, path):
     if path.endswith('.npz'):
         np.savez_compressed(path, **dic)
     else:
-        with tf.Graph().as_default(), \
-                tf.Session() as sess:
+        with tfv1.Graph().as_default(), \
+                tfv1.Session() as sess:
             for k, v in six.iteritems(dic):
                 k = get_op_tensor_name(k)[0]
-                _ = tf.Variable(name=k, initial_value=v)    # noqa
-            sess.run(tf.global_variables_initializer())
-            saver = tf.train.Saver()
+                _ = tfv1.Variable(name=k, initial_value=v)    # noqa
+            sess.run(tfv1.global_variables_initializer())
+            saver = tfv1.train.Saver()
             saver.save(sess, path, write_meta_graph=False)
...
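A round-trip sketch for the renamed pair (paths and names are illustrative): '.npz' paths go through np.savez_compressed, everything else through the tf.train.Saver branch shown above, and load_checkpoint_vars reads the resulting TF checkpoint back into a dict:

    import numpy as np
    from tensorpack.tfutils.varmanip import save_checkpoint_vars, load_checkpoint_vars

    dic = {'fc/W:0': np.random.rand(3, 4).astype('float32')}
    save_checkpoint_vars(dic, '/tmp/ckpt/model-0')      # TF-checkpoint branch; /tmp/ckpt must exist
    loaded = load_checkpoint_vars('/tmp/ckpt/model-0')  # dict: variable name -> np.ndarray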
@@ -197,7 +201,7 @@ def get_checkpoint_path(path):
         path = os.path.join('.', path)  # avoid #4921 and #6142
     if os.path.basename(path) == 'checkpoint':
         assert tfv1.gfile.Exists(path), path
-        path = tf.train.latest_checkpoint(os.path.dirname(path))
+        path = tfv1.train.latest_checkpoint(os.path.dirname(path))
         # to be consistent with either v1 or v2

     # fix paths if provided a wrong one
...
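A behavior sketch for the surrounding function (inputs hypothetical): a 'checkpoint' index file is resolved to the newest checkpoint in its directory, and the "fix paths" step after this hunk presumably strips saver suffixes so any file of a checkpoint maps back to its handle:

    # get_checkpoint_path('train_log/checkpoint')      -> latest checkpoint in train_log/
    # get_checkpoint_path('train_log/model-300.index') -> 'train_log/model-300'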
@@ -214,7 +218,31 @@
     return path


-def load_chkpt_vars(path):
+def get_all_checkpoints(dir: str, prefix: str = "model"):
+    """
+    Get a sorted list of all checkpoints found in directory.
+
+    Args:
+        dir (str): checkpoint directory
+        prefix (str): common prefix among all checkpoints (without the final "-")
+
+    Returns:
+        list[(str, int)]: list of (name, step) sorted by step.
+            Name is a checkpoint handle that can be passed to
+            `tf.train.NewCheckpointReader` or :func:`load_checkpoint_vars`.
+    """
+    def step_from_filename(name):
+        name = os.path.basename(name)
+        name = name[len(f"{prefix}-"):-len(".index")]
+        return int(name)
+
+    checkpoints = glob.glob(os.path.join(dir, "model-*.index"))
+    checkpoints = [(f, step_from_filename(f)) for f in checkpoints]
+    checkpoints = sorted(checkpoints, key=operator.itemgetter(1))
+    return checkpoints
+
+
+def load_checkpoint_vars(path):
     """ Load all variables from a checkpoint to a dict.

     Args:
...
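A usage sketch for the new helper (the directory is a placeholder). Note that in this revision the glob pattern is hard-coded to "model-*.index", so the prefix argument only affects how the step number is parsed out of each filename:

    from tensorpack.tfutils.varmanip import get_all_checkpoints, load_checkpoint_vars

    ckpts = get_all_checkpoints('/path/to/train_log')  # [(handle, step), ...] sorted by step
    if ckpts:
        handle, step = ckpts[-1]               # newest checkpoint
        params = load_checkpoint_vars(handle)  # per the docstring, handles are accepted here
        print(step, len(params))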
@@ -257,3 +285,7 @@ def is_training_name(name):
     if name.startswith('apply_gradients'):
         return True
     return False
+
+
+load_chkpt_vars = load_checkpoint_vars
+save_chkpt_vars = save_checkpoint_vars
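The two module-level assignments keep the old spellings importable, so existing user code keeps working. A quick check, as a sketch:

    from tensorpack.tfutils import varmanip

    # Both spellings refer to the same function objects after this commit.
    assert varmanip.load_chkpt_vars is varmanip.load_checkpoint_vars
    assert varmanip.save_chkpt_vars is varmanip.save_checkpoint_vars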