Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8b4d4f77
Commit
8b4d4f77
authored
Apr 05, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dump-model without GPU. some checks on windows support
parent
c8a9e4e5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
33 deletions
+54
-33
scripts/dump-model-params.py
scripts/dump-model-params.py
+26
-27
tensorpack/callbacks/prof.py
tensorpack/callbacks/prof.py
+1
-0
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+2
-0
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+25
-6
No files found.
scripts/dump-model-params.py
View file @
8b4d4f77
...
@@ -3,11 +3,14 @@
...
@@ -3,11 +3,14 @@
# File: dump-model-params.py
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
six
import
argparse
import
argparse
import
os
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorpack
import
logger
from
tensorpack
.tfutils
import
varmanip
from
tensorpack.tfutils
import
varmanip
,
get_model_loader
from
tensorpack.tfutils
.common
import
get_op_tensor_name
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -17,31 +20,27 @@ if __name__ == '__main__':
...
@@ -17,31 +20,27 @@ if __name__ == '__main__':
parser
.
add_argument
(
dest
=
'output'
,
help
=
'output model file, can be npz or TF checkpoint'
)
parser
.
add_argument
(
dest
=
'output'
,
help
=
'output model file, can be npz or TF checkpoint'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# this script does not need GPU
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
tf
.
train
.
import_meta_graph
(
args
.
meta
,
clear_devices
=
True
)
tf
.
train
.
import_meta_graph
(
args
.
meta
,
clear_devices
=
True
)
# loading...
# loading...
init
=
get_model_loader
(
args
.
input
)
if
args
.
input
.
endswith
(
'.npz'
):
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
dic
=
np
.
load
(
args
.
input
)
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
init
.
init
(
sess
)
# dump ...
with
sess
.
as_default
():
if
args
.
output
.
endswith
(
'npy'
)
or
args
.
output
.
endswith
(
'npz'
):
varmanip
.
dump_session_params
(
args
.
output
)
else
:
else
:
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
dic
=
varmanip
.
load_chkpt_vars
(
args
.
input
)
var
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
dic
=
{
get_op_tensor_name
(
k
)[
1
]:
v
for
k
,
v
in
six
.
iteritems
(
dic
)}
gvars
=
set
([
k
.
name
for
k
in
tf
.
global_variables
()])
var
=
[
v
for
v
in
var
if
v
.
name
in
gvars
]
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_dict
=
{}
var_to_dump
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
for
v
in
var
:
var_to_dump
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
name
=
varmanip
.
get_savename_from_varname
(
v
.
name
)
assert
len
(
set
(
var_to_dump
))
==
len
(
var_to_dump
),
"TRAINABLE and MODEL variables have duplication!"
var_dict
[
name
]
=
v
globvarname
=
[
k
.
name
for
k
in
tf
.
global_variables
()]
logger
.
info
(
"Variables to dump:"
)
var_to_dump
=
set
([
k
.
name
for
k
in
var_to_dump
if
k
.
name
in
globvarname
])
logger
.
info
(
", "
.
join
(
var_dict
.
keys
()))
saver
=
tf
.
train
.
Saver
(
for
name
in
var_to_dump
:
var_list
=
var_dict
,
assert
name
in
dic
,
"Variable {} not found in the model!"
.
format
(
name
)
write_version
=
tf
.
train
.
SaverDef
.
V2
)
saver
.
save
(
sess
,
args
.
output
,
write_meta_graph
=
False
)
dic_to_dump
=
{
k
:
v
for
k
,
v
in
six
.
iteritems
(
dic
)
if
k
in
var_to_dump
}
varmanip
.
save_chkpt_vars
(
dic_to_dump
,
args
.
output
)
tensorpack/callbacks/prof.py
View file @
8b4d4f77
...
@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback):
...
@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback):
Args:
Args:
devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
"""
"""
assert
os
.
name
!=
'nt'
,
"GPUUtilizationTracker does not support windows!"
if
devices
is
None
:
if
devices
is
None
:
env
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
)
env
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
)
if
env
is
None
:
if
env
is
None
:
...
...
tensorpack/dataflow/parallel.py
View file @
8b4d4f77
...
@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
...
@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use.
nr_proc (int): number of processes to use.
"""
"""
if
os
.
name
==
'nt'
:
logger
.
warn
(
"MultiProcessPrefetchData may not support windows!"
)
super
(
MultiProcessPrefetchData
,
self
)
.
__init__
(
ds
)
super
(
MultiProcessPrefetchData
,
self
)
.
__init__
(
ds
)
try
:
try
:
self
.
_size
=
ds
.
size
()
self
.
_size
=
ds
.
size
()
...
...
tensorpack/tfutils/varmanip.py
View file @
8b4d4f77
...
@@ -117,6 +117,7 @@ def dump_session_params(path):
...
@@ -117,6 +117,7 @@ def dump_session_params(path):
Args:
Args:
path(str): the file name to save the parameters. Must ends with npz.
path(str): the file name to save the parameters. Must ends with npz.
"""
"""
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
var
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
# TODO dedup
# TODO dedup
...
@@ -126,15 +127,33 @@ def dump_session_params(path):
...
@@ -126,15 +127,33 @@ def dump_session_params(path):
result
=
{}
result
=
{}
for
v
in
var
:
for
v
in
var
:
result
[
v
.
name
]
=
v
.
eval
()
result
[
v
.
name
]
=
v
.
eval
()
save_chkpt_vars
(
result
,
path
)
def
save_chkpt_vars
(
dic
,
path
):
"""
Save variables in dic to path.
Args:
dic: {name: value}
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
keys
=
sorted
(
list
(
result
.
keys
()))
keys
=
sorted
(
list
(
dic
.
keys
()))
logger
.
info
(
pprint
.
pformat
(
keys
))
logger
.
info
(
pprint
.
pformat
(
keys
))
if
path
.
endswith
(
'.npy'
):
np
.
save
(
path
,
result
)
assert
not
path
.
endswith
(
'.npy'
)
el
if
path
.
endswith
(
'.npz'
):
if
path
.
endswith
(
'.npz'
):
np
.
savez_compressed
(
path
,
**
result
)
np
.
savez_compressed
(
path
,
**
dic
)
else
:
else
:
raise
ValueError
(
"Don't know which format to use for {}"
.
format
(
path
))
with
tf
.
Graph
()
.
as_default
(),
\
tf
.
Session
()
as
sess
:
for
k
,
v
in
six
.
iteritems
(
dic
):
k
=
get_op_tensor_name
(
k
)[
0
]
_
=
tf
.
Variable
(
name
=
k
,
initial_value
=
v
)
# noqa
sess
.
run
(
tf
.
global_variables_initializer
())
saver
=
tf
.
train
.
Saver
()
saver
.
save
(
sess
,
path
,
write_meta_graph
=
False
)
def
get_checkpoint_path
(
model_path
):
def
get_checkpoint_path
(
model_path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment