Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d5410902
Commit
d5410902
authored
Jul 16, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use input_names in predictconfig
parent
cba97f75
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
35 additions
and
39 deletions
+35
-39
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-0
examples/Atari2600/common.py
examples/Atari2600/common.py
+1
-1
examples/DoReFa-Net/alexnet.py
examples/DoReFa-Net/alexnet.py
+1
-1
examples/load-alexnet.py
examples/load-alexnet.py
+1
-1
examples/load-vgg16.py
examples/load-vgg16.py
+1
-1
scripts/imgclassify.py
scripts/imgclassify.py
+3
-3
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+4
-0
tensorpack/predict/common.py
tensorpack/predict/common.py
+22
-31
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+1
-1
No files found.
examples/Atari2600/DQN.py
View file @
d5410902
...
@@ -196,6 +196,7 @@ if __name__ == '__main__':
...
@@ -196,6 +196,7 @@ if __name__ == '__main__':
cfg
=
PredictConfig
(
cfg
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
session_init
=
SaverRestore
(
args
.
load
),
input_var_names
=
[
'state'
]
output_var_names
=
[
'fct/output:0'
])
output_var_names
=
[
'fct/output:0'
])
if
args
.
task
==
'play'
:
if
args
.
task
==
'play'
:
play_model
(
cfg
)
play_model
(
cfg
)
...
...
examples/Atari2600/common.py
View file @
d5410902
...
@@ -9,7 +9,7 @@ from tqdm import tqdm
...
@@ -9,7 +9,7 @@ from tqdm import tqdm
from
six.moves
import
queue
from
six.moves
import
queue
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.predict
import
PredictConfig
,
get_predict_func
,
MultiProcessPredictWorker
from
tensorpack.predict
import
get_predict_func
from
tensorpack.utils.concurrency
import
*
from
tensorpack.utils.concurrency
import
*
from
tensorpack.utils.stat
import
*
from
tensorpack.utils.stat
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.callbacks
import
*
...
...
examples/DoReFa-Net/alexnet.py
View file @
d5410902
...
@@ -104,7 +104,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
...
@@ -104,7 +104,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
def
run_test
(
model
,
sess_init
,
inputs
):
def
run_test
(
model
,
sess_init
,
inputs
):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
model
,
model
=
model
,
input_
data_mapping
=
[
0
],
input_
var_names
=
[
'input'
],
session_init
=
sess_init
,
session_init
=
sess_init
,
session_config
=
get_default_sess_config
(
0.9
),
session_config
=
get_default_sess_config
(
0.9
),
output_var_names
=
[
'prob:0'
]
output_var_names
=
[
'prob:0'
]
...
...
examples/load-alexnet.py
View file @
d5410902
...
@@ -59,7 +59,7 @@ def run_test(path, input):
...
@@ -59,7 +59,7 @@ def run_test(path, input):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
input_
data_mapping
=
[
0
],
input_
var_names
=
[
'input'
],
session_init
=
ParamRestore
(
param_dict
),
session_init
=
ParamRestore
(
param_dict
),
session_config
=
get_default_sess_config
(
0.9
),
session_config
=
get_default_sess_config
(
0.9
),
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
...
...
examples/load-vgg16.py
View file @
d5410902
...
@@ -76,7 +76,7 @@ def run_test(path, input):
...
@@ -76,7 +76,7 @@ def run_test(path, input):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
input_
data_mapping
=
[
0
],
input_
var_names
=
[
'input'
],
session_init
=
ParamRestore
(
param_dict
),
session_init
=
ParamRestore
(
param_dict
),
session_config
=
get_default_sess_config
(
0.9
),
session_config
=
get_default_sess_config
(
0.9
),
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
output_var_names
=
[
'output:0'
]
# output:0 is the probability distribution
...
...
scripts/imgclassify.py
View file @
d5410902
...
@@ -30,10 +30,10 @@ get_config_func = imp.load_source('config_script', args.config).get_config
...
@@ -30,10 +30,10 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
with
tf
.
Graph
()
.
as_default
()
as
G
:
train_config
=
get_config_func
()
train_config
=
get_config_func
()
M
=
train_config
.
model
config
=
PredictConfig
(
config
=
PredictConfig
(
inputs
=
train_config
.
inputs
,
input_var_names
=
[
M
.
get_input_vars_desc
()[
0
]
.
name
],
# assume first component is image
input_dataset_mapping
=
[
train_config
.
inputs
[
0
]],
# assume first component is image
model
=
M
,
get_model_func
=
train_config
.
get_model_func
,
session_init
=
sessinit
.
SaverRestore
(
args
.
model
),
session_init
=
sessinit
.
SaverRestore
(
args
.
model
),
output_var_names
=
[
'output:0'
]
output_var_names
=
[
'output:0'
]
)
)
...
...
tensorpack/models/model_desc.py
View file @
d5410902
...
@@ -42,6 +42,10 @@ class ModelDesc(object):
...
@@ -42,6 +42,10 @@ class ModelDesc(object):
g
=
tf
.
get_default_graph
()
g
=
tf
.
get_default_graph
()
return
[
g
.
get_tensor_by_name
(
name
+
":0"
)
for
name
in
input_var_names
]
return
[
g
.
get_tensor_by_name
(
name
+
":0"
)
for
name
in
input_var_names
]
def
get_input_vars_desc
(
self
):
""" return a list of `InputVar` instance"""
return
self
.
_get_input_vars
()
@
abstractmethod
@
abstractmethod
def
_get_input_vars
(
self
):
def
_get_input_vars
(
self
):
""":returns: a list of InputVar """
""":returns: a list of InputVar """
...
...
tensorpack/predict/common.py
View file @
d5410902
...
@@ -7,6 +7,7 @@ from collections import namedtuple
...
@@ -7,6 +7,7 @@ from collections import namedtuple
from
six.moves
import
zip
from
six.moves
import
zip
from
tensorpack.models
import
ModelDesc
from
tensorpack.models
import
ModelDesc
from
..utils
import
logger
from
..tfutils
import
*
from
..tfutils
import
*
import
multiprocessing
import
multiprocessing
...
@@ -22,26 +23,8 @@ class PredictConfig(object):
...
@@ -22,26 +23,8 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data
:param input_var_names: a list of input variable names.
to the input tensor, since you may not need all input variables
:param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
of the Model to run the graph for prediction (for example
the `label` input is not used if you only need probability distribution).
It should be a list of int with length equal to `len(data_point)`,
where each element in the list defines which input variables each
component in the data point should be fed into.
If not given, defaults to range(len(input_vars))
For example, in image classification task, the testing
dataset only provides datapoints of images (no labels). When
the input variables of the model is: ::
input_vars: [image_var, label_var]
the mapping should then look like: ::
input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
:param model: a `ModelDesc` instance
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
variables can be any computable tensor in the graph.
...
@@ -58,8 +41,21 @@ class PredictConfig(object):
...
@@ -58,8 +41,21 @@ class PredictConfig(object):
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
input_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
if
input_mapping
:
raw_vars
=
self
.
model
.
get_input_vars_desc
()
self
.
input_var_names
=
[
raw_vars
[
k
]
.
name
for
k
in
input_mapping
]
logger
.
warn
(
'The option `input_data_mapping` was deprecated.
\
Use
\'
input_var_names=[{}]
\'
instead'
.
format
(
', '
.
join
(
self
.
input_var_names
)))
elif
self
.
input_var_names
is
None
:
# neither options is set, assume all inputs
raw_vars
=
self
.
model
.
get_input_vars_desc
()
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
assert
len
(
self
.
input_var_names
),
self
.
input_var_names
assert
len
(
self
.
output_var_names
),
self
.
output_var_names
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
...
@@ -71,24 +67,19 @@ def get_predict_func(config):
...
@@ -71,24 +67,19 @@ def get_predict_func(config):
:returns: A prediction function that takes a list of input values, and return
:returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
a list of output values defined in ``config.output_var_names``.
"""
"""
output_var_names
=
config
.
output_var_names
# build graph
# input/output variables
input_vars
=
config
.
model
.
get_input_vars
()
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
config
.
model
.
_build_graph
(
input_vars
,
False
)
if
config
.
input_data_mapping
is
None
:
input_map
=
input_vars
else
:
input_map
=
[
input_vars
[
k
]
for
k
in
config
.
input_data_mapping
if
k
>=
0
]
# check output_var_names against output_vars
input_vars
=
get_vars_by_names
(
config
.
input_var_names
)
output_vars
=
get_vars_by_names
(
output_var_names
)
output_vars
=
get_vars_by_names
(
config
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
def
run_input
(
dp
):
def
run_input
(
dp
):
feed
=
dict
(
zip
(
input_map
,
dp
))
assert
len
(
input_vars
)
==
len
(
dp
),
"{} != {}"
.
format
(
len
(
input_vars
),
len
(
dp
))
feed
=
dict
(
zip
(
input_vars
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
# XXX hack. so the caller can get access to the session.
# XXX hack. so the caller can get access to the session.
run_input
.
session
=
sess
run_input
.
session
=
sess
...
...
tensorpack/predict/concurrency.py
View file @
d5410902
...
@@ -96,7 +96,7 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -96,7 +96,7 @@ class PredictorWorkerThread(threading.Thread):
while
True
:
while
True
:
batched
,
futures
=
self
.
fetch_batch
()
batched
,
futures
=
self
.
fetch_batch
()
outputs
=
self
.
func
(
batched
)
outputs
=
self
.
func
(
batched
)
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
#print "batched size: ", len(batched
[0]
), "queuesize: ", self.queue.qsize()
# debug, for speed testing
# debug, for speed testing
#if self.xxx is None:
#if self.xxx is None:
#self.xxx = outputs = self.func([batched])
#self.xxx = outputs = self.func([batched])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment