Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
35f24d40
Commit
35f24d40
authored
Aug 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
varmanip
parent
f8d4352a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
24 deletions
+33
-24
scripts/dump-model-params.py
scripts/dump-model-params.py
+3
-4
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-18
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+28
-2
No files found.
scripts/dump-model-params.py
View file @
35f24d40
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump
_model_
params.py
# File: dump
-model-
params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
argparse
import
cv2
import
tensorflow
as
tf
import
imp
from
tensorpack.utils
import
*
from
tensorpack.tfutils
import
sessinit
from
tensorpack.tfutils
import
sessinit
,
varmanip
from
tensorpack.dataflow
import
*
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -27,4 +26,4 @@ with tf.Graph().as_default() as G:
sess
=
tf
.
Session
()
init
.
init
(
sess
)
with
sess
.
as_default
():
sessinit
.
dump_session_params
(
args
.
output
)
varmanip
.
dump_session_params
(
args
.
output
)
tensorpack/tfutils/sessinit.py
View file @
35f24d40
...
...
@@ -4,7 +4,6 @@
import
os
from
abc
import
abstractmethod
,
ABCMeta
import
numpy
as
np
from
collections
import
defaultdict
import
re
import
tensorflow
as
tf
...
...
@@ -12,12 +11,11 @@ import six
from
..utils
import
logger
,
EXTRA_SAVE_VARS_KEY
from
.common
import
get_op_var_name
from
.
sessupdate
import
SessionUpdate
from
.
varmanip
import
SessionUpdate
__all__
=
[
'SessionInit'
,
'NewSession'
,
'SaverRestore'
,
'ParamRestore'
,
'ChainInit'
,
'JustCurrentSession'
,
'dump_session_params'
]
'JustCurrentSession'
]
# TODO they initialize_all at the beginning by default.
...
...
@@ -180,17 +178,3 @@ def ChainInit(SessionInit):
def
_init
(
self
,
sess
):
for
i
in
self
.
inits
:
i
.
init
(
sess
)
def
dump_session_params
(
path
):
""" Dump value of all trainable variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
.
extend
(
tf
.
get_collection
(
EXTRA_SAVE_VARS_KEY
))
result
=
{}
for
v
in
var
:
name
=
v
.
name
.
replace
(
":0"
,
""
)
result
[
name
]
=
v
.
eval
()
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
path
,
result
)
tensorpack/tfutils/
sessupdate
.py
→
tensorpack/tfutils/
varmanip
.py
View file @
35f24d40
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File:
sessupdate
.py
# File:
varmanip
.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
six
import
tensorflow
as
tf
import
numpy
as
np
__all__
=
[
'SessionUpdate'
]
__all__
=
[
'SessionUpdate'
,
'dump_session_params'
,
'dump_chkpt_vars'
]
class
SessionUpdate
(
object
):
""" Update the variables in a session """
...
...
@@ -35,3 +36,28 @@ class SessionUpdate(object):
logger
.
warn
(
"Param {} is reshaped during assigning"
.
format
(
name
))
value
=
value
.
reshape
(
varshape
)
self
.
sess
.
run
(
op
,
feed_dict
=
{
p
:
value
})
def
dump_session_params
(
path
):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
.
extend
(
tf
.
get_collection
(
EXTRA_SAVE_VARS_KEY
))
result
=
{}
for
v
in
var
:
name
=
v
.
name
.
replace
(
":0"
,
""
)
result
[
name
]
=
v
.
eval
()
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
path
,
result
)
def
dump_chkpt_vars
(
model_path
,
output
):
""" Dump all variables from a checkpoint """
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
var_names
=
reader
.
get_variable_to_shape_map
()
.
keys
()
result
=
{}
for
n
in
var_names
:
result
[
n
]
=
reader
.
get_tensor
(
n
)
logger
.
info
(
"Variables to save to {}:"
.
format
(
output
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
output
,
result
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment