Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
19069210
Commit
19069210
authored
Feb 09, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fix
parent
a3de7ec7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
9 deletions
+12
-9
example_alexnet.py
example_alexnet.py
+6
-4
tensorpack/predict.py
tensorpack/predict.py
+6
-5
No files found.
example_alexnet.py
View file @
19069210
...
...
@@ -124,7 +124,7 @@ def run_test(path):
param_dict
=
np
.
load
(
path
)
.
item
()
pred_config
=
PredictConfig
(
model
=
Model
s
(),
model
=
Model
(),
input_data_mapping
=
[
0
],
session_init
=
ParamRestore
(
param_dict
),
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
...
...
@@ -139,7 +139,9 @@ def run_test(path):
outputs
=
predict_func
([
im
])[
0
]
prob
=
outputs
[
0
]
print
prob
.
shape
print
prob
.
argsort
()[
-
10
:][::
-
1
]
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
print
ret
assert
ret
[
0
]
==
285
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -148,7 +150,7 @@ if __name__ == '__main__':
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
start_train
(
get_config
())
#
start_train(get_config())
# run alexnet with given model (in npy format)
run_test
(
'alexnet
-tuned
.npy'
)
run_test
(
'alexnet.npy'
)
tensorpack/predict.py
View file @
19069210
...
...
@@ -35,11 +35,12 @@ class PredictConfig(object):
If not given, defaults to range(len(input_vars))
For example, with image classification task, the testing
dataset only provides datapoints of images (no labels). The
arguments should look like:
inputs: [image_var, label_var]
dataset only provides datapoints of images (no labels). When
the input variables of the model is:
input_vars: [image_var, label_var]
the mapping should look like:
input_data_mapping: [0]
If this argument is not set, the inputs and the data points won't be aligned.
If this argument is not set
in this case
, the inputs and the data points won't be aligned.
model: a ModelDesc instance
output_var_names: a list of names of the output variable to predict, the
variables can be any computable tensor in the graph.
...
...
@@ -53,7 +54,7 @@ class PredictConfig(object):
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data
set
_mapping'
,
None
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
,
None
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment