Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e68eec29
Commit
e68eec29
authored
Sep 04, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better warning about user-provided sessconfig
parent
6a0bba68
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
26 deletions
+43
-26
scripts/dump-model-params.py
scripts/dump-model-params.py
+26
-23
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+17
-3
No files found.
scripts/dump-model-params.py
View file @
e68eec29
...
...
@@ -75,24 +75,24 @@ if __name__ == '__main__':
if
os
.
path
.
isdir
(
args
.
input
):
input
,
meta
=
guess_inputs
(
args
.
input
)
else
:
assert
args
.
meta
is
not
None
meta
=
args
.
meta
input
=
args
.
input
# this script does not need GPU
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
while
True
:
try
:
tf
.
reset_default_graph
()
tf
.
train
.
import_meta_graph
(
meta
,
clear_devices
=
True
)
except
KeyError
as
e
:
op_name
=
e
.
args
[
0
]
_import_external_ops
(
op_name
)
except
tf
.
errors
.
NotFoundError
as
e
:
_import_external_ops
(
e
.
message
)
else
:
break
if
args
.
meta
is
not
None
:
while
True
:
try
:
tf
.
reset_default_graph
()
tf
.
train
.
import_meta_graph
(
meta
,
clear_devices
=
True
)
except
KeyError
as
e
:
op_name
=
e
.
args
[
0
]
_import_external_ops
(
op_name
)
except
tf
.
errors
.
NotFoundError
as
e
:
_import_external_ops
(
e
.
message
)
else
:
break
# loading...
if
input
.
endswith
(
'.npz'
):
...
...
@@ -101,17 +101,20 @@ if __name__ == '__main__':
dic
=
varmanip
.
load_chkpt_vars
(
input
)
dic
=
{
get_op_tensor_name
(
k
)[
1
]:
v
for
k
,
v
in
six
.
iteritems
(
dic
)}
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var_to_dump
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
if
len
(
set
(
var_to_dump
))
!=
len
(
var_to_dump
):
logger
.
warn
(
"TRAINABLE and MODEL variables have duplication!"
)
var_to_dump
=
list
(
set
(
var_to_dump
))
globvarname
=
set
([
k
.
name
for
k
in
tf
.
global_variables
()])
var_to_dump
=
set
([
k
.
name
for
k
in
var_to_dump
if
k
.
name
in
globvarname
])
for
name
in
var_to_dump
:
assert
name
in
dic
,
"Variable {} not found in the model!"
.
format
(
name
)
if
args
.
meta
is
not
None
:
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var_to_dump
.
extend
(
tf
.
get_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
))
if
len
(
set
(
var_to_dump
))
!=
len
(
var_to_dump
):
logger
.
warn
(
"TRAINABLE and MODEL variables have duplication!"
)
var_to_dump
=
list
(
set
(
var_to_dump
))
globvarname
=
set
([
k
.
name
for
k
in
tf
.
global_variables
()])
var_to_dump
=
set
([
k
.
name
for
k
in
var_to_dump
if
k
.
name
in
globvarname
])
for
name
in
var_to_dump
:
assert
name
in
dic
,
"Variable {} not found in the model!"
.
format
(
name
)
else
:
var_to_dump
=
set
(
dic
.
keys
())
dic_to_dump
=
{
k
:
v
for
k
,
v
in
six
.
iteritems
(
dic
)
if
k
in
var_to_dump
}
varmanip
.
save_chkpt_vars
(
dic_to_dump
,
args
.
output
)
tensorpack/tfutils/sesscreate.py
View file @
e68eec29
...
...
@@ -17,6 +17,21 @@ A SessionCreator should:
"""
_WRN1
=
"""User-provided custom session config may not work due to TF bugs. If you saw logs like
```
tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
```
before this line, then your GPU has been initialized and custom GPU options may not take effect. """
_WRN2
=
"""To work around this issue, you can do one of the following:
1. Avoid initializing the GPU too early. Find code that initializes the GPU and skip it.
Typical examples are: creating a session; checking GPU availability; checking the GPU number.
2. Manually set your GPU options earlier. You can create a session with custom
GPU options at the beginning of your program, as described in
https://github.com/tensorpack/tensorpack/issues/497
"""
class
NewSessionCreator
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
config
=
None
):
"""
...
...
@@ -33,9 +48,8 @@ class NewSessionCreator(tf.train.SessionCreator):
config
=
get_default_sess_config
()
else
:
self
.
user_provided_config
=
True
logger
.
warn
(
"User-provided custom session config may not work due to TF
\
bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds."
)
logger
.
warn
(
_WRN1
)
logger
.
warn
(
_WRN2
)
self
.
config
=
config
def
create_session
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment