Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c5df0501
Commit
c5df0501
authored
Dec 31, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
input_dataset_mapping for prediction
parent
9bb0b8f6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
6 deletions
+26
-6
example_alexnet.py
example_alexnet.py
+2
-1
tensorpack/predict.py
tensorpack/predict.py
+24
-5
No files found.
example_alexnet.py
View file @
c5df0501
...
...
@@ -135,6 +135,7 @@ def run_test(path):
pred_config
=
PredictConfig
(
inputs
=
input_vars
,
input_dataset_mapping
=
[
input_vars
[
0
]],
get_model_func
=
get_model
,
session_init
=
ParamRestore
(
param_dict
),
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
...
...
@@ -146,7 +147,7 @@ def run_test(path):
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
cv2
.
resize
(
im
,
(
227
,
227
))
im
=
np
.
reshape
(
im
,
(
1
,
227
,
227
,
3
))
outputs
=
predict_func
([
im
,
(
1
,)
])[
0
]
outputs
=
predict_func
([
im
])[
0
]
prob
=
outputs
[
0
]
print
prob
.
shape
print
prob
.
argsort
()[
-
10
:][::
-
1
]
...
...
tensorpack/predict.py
View file @
c5df0501
...
...
@@ -22,8 +22,22 @@ class PredictConfig(object):
session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session.
inputs: a list of input variables. must match the dataset later
used for prediction.
inputs: input variables of the graph.
input_dataset_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example
the `label` input is not used if you only need probability
distribution). It should be a list with size=len(one_data_point),
where each element is a tensor which each component of the
data point should be fed into.
If not given, defaults to `inputs`.
For example, with image classification task, the testing
dataset only provides datapoints of images (no labels). The
arguments should look like:
inputs: [image_var, label_var]
input_dataset_mapping: [image_var]
If this argument is not set, the inputs and the data points won't be aligned.
get_model_func: a function taking `inputs` and `is_training` and
return a tuple of output list as well as the cost to minimize
output_var_names: a list of names of the output variable to predict, the
...
...
@@ -38,7 +52,7 @@ class PredictConfig(object):
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
inputs
=
kwargs
.
pop
(
'inputs'
)
[
assert_type
(
i
,
tf
.
Tensor
)
for
i
in
self
.
inputs
]
self
.
input_dataset_mapping
=
kwargs
.
pop
(
'input_dataset_mapping'
,
None
)
self
.
get_model_func
=
kwargs
.
pop
(
'get_model_func'
)
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
,
None
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
...
...
@@ -59,6 +73,9 @@ def get_predict_func(config):
# input/output variables
input_vars
=
config
.
inputs
output_vars
,
cost_var
=
config
.
get_model_func
(
input_vars
,
is_training
=
False
)
input_map
=
config
.
input_dataset_mapping
if
input_map
is
None
:
input_map
=
input_vars
# check output_var_names against output_vars
if
output_var_names
is
not
None
:
...
...
@@ -70,8 +87,10 @@ def get_predict_func(config):
config
.
session_init
.
init
(
sess
)
def
run_input
(
dp
):
# TODO if input and dp not aligned?
feed
=
dict
(
zip
(
input_vars
,
dp
))
assert
len
(
input_map
)
==
len
(
dp
),
\
"Graph has {} inputs but dataset only gives {} components!"
.
format
(
len
(
input_map
),
len
(
dp
))
feed
=
dict
(
zip
(
input_map
,
dp
))
if
output_var_names
is
not
None
:
results
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
results
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment