Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c44b65fc
Commit
c44b65fc
authored
Jan 03, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve error messages in TrainConfig/PredictConfig type checking (#1029)
parent
2ce43d70
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
23 deletions
+29
-23
examples/basics/export-model.py
examples/basics/export-model.py
+6
-6
scripts/ls-checkpoint.py
scripts/ls-checkpoint.py
+1
-1
tensorpack/predict/config.py
tensorpack/predict/config.py
+10
-7
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+1
-0
tensorpack/train/config.py
tensorpack/train/config.py
+11
-9
No files found.
examples/basics/export-model.py
View file @
c44b65fc
...
...
@@ -19,18 +19,18 @@ The steps are:
1. train the model by
python export.py
python export
-model
.py
2. export the model by
python export.py --export serving --load train_log/export/checkpoint
python export.py --export compact --load train_log/export/checkpoint
python export
-model
.py --export serving --load train_log/export/checkpoint
python export
-model
.py --export compact --load train_log/export/checkpoint
3. run inference by
python export.py --apply default --load train_log/export/checkpoint
python export.py --apply inference_graph --load train_log/export/checkpoint
python export.py --apply compact --load /tmp/compact_graph.pb
python export
-model
.py --apply default --load train_log/export/checkpoint
python export
-model
.py --apply inference_graph --load train_log/export/checkpoint
python export
-model
.py --apply compact --load /tmp/compact_graph.pb
"""
...
...
scripts/ls-checkpoint.py
View file @
c44b65fc
...
...
@@ -20,7 +20,7 @@ if __name__ == '__main__':
params
=
dict
(
np
.
load
(
fpath
))
dic
=
{
k
:
v
.
shape
for
k
,
v
in
six
.
iteritems
(
params
)}
else
:
path
=
get_checkpoint_path
(
sys
.
argv
[
1
]
)
path
=
get_checkpoint_path
(
fpath
)
reader
=
tf
.
train
.
NewCheckpointReader
(
path
)
dic
=
reader
.
get_variable_to_shape_map
()
pprint
.
pprint
(
dic
)
tensorpack/predict/config.py
View file @
c44b65fc
...
...
@@ -53,10 +53,13 @@ class PredictConfig(object):
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
"""
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
def
assert_type
(
v
,
tp
,
name
):
assert
isinstance
(
v
,
tp
),
\
"{} has to be type '{}', but an object of type '{}' found."
.
format
(
name
,
tp
.
__name__
,
v
.
__class__
.
__name__
)
if
model
is
not
None
:
assert_type
(
model
,
ModelDescBase
)
assert_type
(
model
,
ModelDescBase
,
'model'
)
assert
inputs_desc
is
None
and
tower_func
is
None
self
.
inputs_desc
=
model
.
get_inputs_desc
()
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
self
.
inputs_desc
)
...
...
@@ -70,7 +73,7 @@ class PredictConfig(object):
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
,
'session_init'
)
if
session_creator
is
None
:
self
.
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
config
=
get_default_sess_config
())
...
...
@@ -82,13 +85,13 @@ class PredictConfig(object):
if
self
.
input_names
is
None
:
self
.
input_names
=
[
k
.
name
for
k
in
self
.
inputs_desc
]
self
.
output_names
=
output_names
assert_type
(
self
.
output_names
,
list
)
assert_type
(
self
.
input_names
,
list
)
assert_type
(
self
.
output_names
,
list
,
'output_names'
)
assert_type
(
self
.
input_names
,
list
,
'input_names'
)
if
len
(
self
.
input_names
)
==
0
:
logger
.
warn
(
'PredictConfig receives empty "input_names".'
)
# assert len(self.input_names), self.input_names
for
v
in
self
.
input_names
:
assert_type
(
v
,
six
.
string_types
)
assert_type
(
v
,
six
.
string_types
,
'Each item in input_names'
)
assert
len
(
self
.
output_names
),
self
.
output_names
self
.
return_input
=
bool
(
return_input
)
...
...
tensorpack/tfutils/sessinit.py
View file @
c44b65fc
...
...
@@ -248,6 +248,7 @@ def get_model_loader(filename):
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
"""
assert
isinstance
(
filename
,
six
.
string_types
),
filename
if
filename
.
endswith
(
'.npy'
):
assert
tf
.
gfile
.
Exists
(
filename
),
filename
return
DictRestore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
...
...
tensorpack/train/config.py
View file @
c44b65fc
...
...
@@ -98,33 +98,35 @@ class TrainConfig(object):
"""
# TODO type checker decorator
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
def
assert_type
(
v
,
tp
,
name
):
assert
isinstance
(
v
,
tp
),
\
"{} has to be type '{}', but an object of type '{}' found."
.
format
(
name
,
tp
.
__name__
,
v
.
__class__
.
__name__
)
# process data & model
assert
data
is
None
or
dataflow
is
None
,
"dataflow and data cannot be both presented in TrainConfig!"
if
dataflow
is
not
None
:
assert_type
(
dataflow
,
DataFlow
)
assert_type
(
dataflow
,
DataFlow
,
'dataflow'
)
if
data
is
not
None
:
assert_type
(
data
,
InputSource
)
assert_type
(
data
,
InputSource
,
'data'
)
self
.
dataflow
=
dataflow
self
.
data
=
data
if
model
is
not
None
:
assert_type
(
model
,
ModelDescBase
)
assert_type
(
model
,
ModelDescBase
,
'model'
)
self
.
model
=
model
if
callbacks
is
not
None
:
assert_type
(
callbacks
,
list
)
assert_type
(
callbacks
,
list
,
'callbacks'
)
self
.
callbacks
=
callbacks
if
extra_callbacks
is
not
None
:
assert_type
(
extra_callbacks
,
list
)
assert_type
(
extra_callbacks
,
list
,
'extra_callbacks'
)
self
.
extra_callbacks
=
extra_callbacks
if
monitors
is
not
None
:
assert_type
(
monitors
,
list
)
assert_type
(
monitors
,
list
,
'monitors'
)
self
.
monitors
=
monitors
if
session_init
is
not
None
:
assert_type
(
session_init
,
SessionInit
)
assert_type
(
session_init
,
SessionInit
,
'session_init'
)
self
.
session_init
=
session_init
if
session_creator
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment