seminar-breakout (Shashank Suhas) · Commits

Commit 8b4d4f77
Authored Apr 05, 2018 by Yuxin Wu
Parent: c8a9e4e5

dump-model without GPU. some checks on windows support

Showing 4 changed files with 54 additions and 33 deletions (+54 -33)
scripts/dump-model-params.py      +26  -27
tensorpack/callbacks/prof.py       +1   -0
tensorpack/dataflow/parallel.py    +2   -0
tensorpack/tfutils/varmanip.py    +25   -6
scripts/dump-model-params.py

@@ -3,11 +3,14 @@
 # File: dump-model-params.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
+import numpy as np
+import six
 import argparse
+import os
 import tensorflow as tf
 
 from tensorpack import logger
-from tensorpack.tfutils import varmanip, get_model_loader
+from tensorpack.tfutils import varmanip
+from tensorpack.tfutils.common import get_op_tensor_name
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
@@ -17,31 +20,27 @@ if __name__ == '__main__':
     parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
     args = parser.parse_args()
 
+    # this script does not need GPU
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
     tf.train.import_meta_graph(args.meta, clear_devices=True)
 
     # loading...
-    init = get_model_loader(args.input)
-    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
-    sess.run(tf.global_variables_initializer())
-    sess.run(tf.local_variables_initializer())
-    init.init(sess)
-
-    # dump ...
-    with sess.as_default():
-        if args.output.endswith('npy') or args.output.endswith('npz'):
-            varmanip.dump_session_params(args.output)
-        else:
-            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-            var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
-            gvars = set([k.name for k in tf.global_variables()])
-            var = [v for v in var if v.name in gvars]
-            var_dict = {}
-            for v in var:
-                name = varmanip.get_savename_from_varname(v.name)
-                var_dict[name] = v
-            logger.info("Variables to dump:")
-            logger.info(", ".join(var_dict.keys()))
-            saver = tf.train.Saver(
-                var_list=var_dict,
-                write_version=tf.train.SaverDef.V2)
-            saver.save(sess, args.output, write_meta_graph=False)
+    if args.input.endswith('.npz'):
+        dic = np.load(args.input)
+    else:
+        dic = varmanip.load_chkpt_vars(args.input)
+    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
+
+    # save variables that are GLOBAL, and either TRAINABLE or MODEL
+    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
+    assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
+    globvarname = [k.name for k in tf.global_variables()]
+    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
+
+    for name in var_to_dump:
+        assert name in dic, "Variable {} not found in the model!".format(name)
+
+    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
+    varmanip.save_chkpt_vars(dic_to_dump, args.output)
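The script can now run on a machine without a GPU because it no longer builds a live session around get_model_loader; it reads values straight out of the input file instead. Below is a minimal sketch of that session-free loading idea, assuming (as tensorpack's load_chkpt_vars appears to do) that a checkpoint can be read with TF's NewCheckpointReader; the checkpoint path is hypothetical:

    import tensorflow as tf

    # Read variables from a checkpoint without creating a session, so no
    # GPU (and no graph) is needed. 'model.ckpt' is a made-up path.
    reader = tf.train.NewCheckpointReader('model.ckpt')
    dic = {name: reader.get_tensor(name)
           for name in reader.get_variable_to_shape_map()}
    print(sorted(dic.keys()))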
tensorpack/callbacks/prof.py

@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback):
         Args:
             devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
         """
+        assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"
         if devices is None:
             env = os.environ.get('CUDA_VISIBLE_DEVICES')
             if env is None:
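The guard relies on os.name, which is 'posix' on Linux/macOS and 'nt' on Windows, so an unsupported platform fails fast at construction time with a clear message instead of breaking later inside the callback. A standalone sketch of the same check:

    import os

    # os.name is 'nt' on Windows; elsewhere this assert is a no-op.
    assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"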
tensorpack/dataflow/parallel.py

@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
             nr_prefetch (int): size of the queue to hold prefetched datapoints.
             nr_proc (int): number of processes to use.
         """
+        if os.name == 'nt':
+            logger.warn("MultiProcessPrefetchData may not support windows!")
         super(MultiProcessPrefetchData, self).__init__(ds)
         try:
             self._size = ds.size()
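Here the commit warns rather than asserting, suggesting MultiProcessPrefetchData may still work on Windows in some cases. The likely underlying issue, sketched roughly: on POSIX, multiprocessing forks workers that inherit the dataflow, while Windows must pickle it into a spawned process, which can fail for dataflows holding generators or locks. A Python 3 sketch showing the platform difference:

    import multiprocessing as mp

    if __name__ == '__main__':
        # 'fork' on Linux (workers inherit parent state); 'spawn' on
        # Windows, where everything sent to a worker must be picklable.
        print(mp.get_start_method())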
tensorpack/tfutils/varmanip.py

@@ -117,6 +117,7 @@ def dump_session_params(path):
     Args:
         path(str): the file name to save the parameters. Must ends with npz.
     """
+    # save variables that are GLOBAL, and either TRAINABLE or MODEL
     var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
     # TODO dedup
@@ -126,15 +127,33 @@ def dump_session_params(path):
     result = {}
     for v in var:
         result[v.name] = v.eval()
+    save_chkpt_vars(result, path)
+
+
+def save_chkpt_vars(dic, path):
+    """
+    Save variables in dic to path.
+
+    Args:
+        dic: {name: value}
+        path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
+    """
     logger.info("Variables to save to {}:".format(path))
-    keys = sorted(list(result.keys()))
+    keys = sorted(list(dic.keys()))
     logger.info(pprint.pformat(keys))
+
-    if path.endswith('.npy'):
-        np.save(path, result)
-    elif path.endswith('.npz'):
-        np.savez_compressed(path, **result)
+    assert not path.endswith('.npy')
+    if path.endswith('.npz'):
+        np.savez_compressed(path, **dic)
     else:
-        raise ValueError("Don't know which format to use for {}".format(path))
+        with tf.Graph().as_default(), \
+                tf.Session() as sess:
+            for k, v in six.iteritems(dic):
+                k = get_op_tensor_name(k)[0]
+                _ = tf.Variable(name=k, initial_value=v)  # noqa
+            sess.run(tf.global_variables_initializer())
+            saver = tf.train.Saver()
+            saver.save(sess, path, write_meta_graph=False)
 
 
 def get_checkpoint_path(model_path):
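A hypothetical usage of the newly extracted save_chkpt_vars, following its docstring above (the array name and output paths are made up): the same dict can go to a compressed .npz, or, via the throwaway graph in the else branch, to a TF checkpoint:

    import numpy as np
    from tensorpack.tfutils import varmanip

    params = {'conv0/W': np.zeros((3, 3, 3, 64), dtype='float32')}
    varmanip.save_chkpt_vars(params, '/tmp/model.npz')   # -> np.savez_compressed
    varmanip.save_chkpt_vars(params, '/tmp/model-ckpt')  # -> tf.train.Saver checkpoint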