Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5cccf2b8
Commit
5cccf2b8
authored
Nov 26, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
compatible with old version
parent
dc378b53
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
4 deletions
+7
-4
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+1
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+5
-2
No files found.
tensorpack/callbacks/common.py
View file @
5cccf2b8
...
@@ -18,7 +18,7 @@ class ModelSaver(Callback):
...
@@ -18,7 +18,7 @@ class ModelSaver(Callback):
Save the model to logger directory.
Save the model to logger directory.
"""
"""
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
var_collections
=
tf
.
GraphKeys
.
GLOBAL_
VARIABLES
):
var_collections
=
tf
.
GraphKeys
()
.
VARIABLES
):
"""
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
...
...
tensorpack/models/model_desc.py
View file @
5cccf2b8
...
@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
...
@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
tf
.
train
.
import_meta_graph
(
filename
)
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
for
k
in
[
INPUT_VARS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
for
k
in
[
INPUT_VARS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
GraphKeys
.
GLOBAL_
VARIABLES
]:
tf
.
GraphKeys
()
.
VARIABLES
]:
assert
k
in
all_coll
,
\
assert
k
in
all_coll
,
\
"Collection {} not found in metagraph!"
.
format
(
k
)
"Collection {} not found in metagraph!"
.
format
(
k
)
...
...
tensorpack/tfutils/sessinit.py
View file @
5cccf2b8
...
@@ -113,7 +113,10 @@ class SaverRestore(SessionInit):
...
@@ -113,7 +113,10 @@ class SaverRestore(SessionInit):
:param vars_available: varaible names available in the checkpoint, for existence checking
:param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
:returns: a dict of {var_name: [var, var]} to restore
"""
"""
vars_to_restore
=
tf
.
global_variables
()
try
:
vars_to_restore
=
tf
.
global_variables
()
except
AttributeError
:
vars_to_restore
=
tf
.
all_variables
()
var_dict
=
defaultdict
(
list
)
var_dict
=
defaultdict
(
list
)
chkpt_vars_used
=
set
()
chkpt_vars_used
=
set
()
for
v
in
vars_to_restore
:
for
v
in
vars_to_restore
:
...
@@ -150,7 +153,7 @@ class ParamRestore(SessionInit):
...
@@ -150,7 +153,7 @@ class ParamRestore(SessionInit):
self
.
prms
=
{
get_op_var_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
self
.
prms
=
{
get_op_var_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
()
.
VARIABLES
)
# TODO
variable_names
=
set
([
get_savename_from_varname
(
k
.
name
)
for
k
in
variables
])
variable_names
=
set
([
get_savename_from_varname
(
k
.
name
)
for
k
in
variables
])
param_names
=
set
(
six
.
iterkeys
(
self
.
prms
))
param_names
=
set
(
six
.
iterkeys
(
self
.
prms
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment