Commit d5410902 authored by Yuxin Wu's avatar Yuxin Wu

use input_names in predictconfig

parent cba97f75
......@@ -196,6 +196,7 @@ if __name__ == '__main__':
cfg = PredictConfig(
model=Model(),
session_init=SaverRestore(args.load),
input_var_names=['state']
output_var_names=['fct/output:0'])
if args.task == 'play':
play_model(cfg)
......
......@@ -9,7 +9,7 @@ from tqdm import tqdm
from six.moves import queue
from tensorpack import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.predict import get_predict_func
from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import *
from tensorpack.callbacks import *
......
......@@ -104,7 +104,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
def run_test(model, sess_init, inputs):
pred_config = PredictConfig(
model=model,
input_data_mapping=[0],
input_var_names=['input'],
session_init=sess_init,
session_config=get_default_sess_config(0.9),
output_var_names=['prob:0']
......
......@@ -59,7 +59,7 @@ def run_test(path, input):
pred_config = PredictConfig(
model=Model(),
input_data_mapping=[0],
input_var_names=['input'],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution
......
......@@ -76,7 +76,7 @@ def run_test(path, input):
pred_config = PredictConfig(
model=Model(),
input_data_mapping=[0],
input_var_names=['input'],
session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution
......
......@@ -30,10 +30,10 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G:
train_config = get_config_func()
M = train_config.model
config = PredictConfig(
inputs=train_config.inputs,
input_dataset_mapping=[train_config.inputs[0]], # assume first component is image
get_model_func=train_config.get_model_func,
input_var_names=[M.get_input_vars_desc()[0].name], # assume first component is image
model=M,
session_init=sessinit.SaverRestore(args.model),
output_var_names=['output:0']
)
......
......@@ -42,6 +42,10 @@ class ModelDesc(object):
g = tf.get_default_graph()
return [g.get_tensor_by_name(name + ":0") for name in input_var_names]
def get_input_vars_desc(self):
""" return a list of `InputVar` instance"""
return self._get_input_vars()
@abstractmethod
def _get_input_vars(self):
""":returns: a list of InputVar """
......
......@@ -7,6 +7,7 @@ from collections import namedtuple
from six.moves import zip
from tensorpack.models import ModelDesc
from ..utils import logger
from ..tfutils import *
import multiprocessing
......@@ -22,26 +23,8 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the Model to run the graph for prediction (for example
the `label` input is not used if you only need probability distribution).
It should be a list of int with length equal to `len(data_point)`,
where each element in the list defines which input variables each
component in the data point should be fed into.
If not given, defaults to range(len(input_vars))
For example, in image classification task, the testing
dataset only provides datapoints of images (no labels). When
the input variables of the model is: ::
input_vars: [image_var, label_var]
the mapping should then look like: ::
input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
:param input_var_names: a list of input variable names.
:param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
......@@ -58,8 +41,21 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None)
if input_mapping:
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [raw_vars[k].name for k in input_mapping]
logger.warn('The option `input_data_mapping` was deprecated. \
Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
elif self.input_var_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names')
assert len(self.input_var_names), self.input_var_names
assert len(self.output_var_names), self.output_var_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......@@ -71,24 +67,19 @@ def get_predict_func(config):
:returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
"""
output_var_names = config.output_var_names
# input/output variables
# build graph
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
if config.input_data_mapping is None:
input_map = input_vars
else:
input_map = [input_vars[k] for k in config.input_data_mapping if k >= 0]
# check output_var_names against output_vars
output_vars = get_vars_by_names(output_var_names)
input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(config.output_var_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
def run_input(dp):
feed = dict(zip(input_map, dp))
assert len(input_vars) == len(dp), "{} != {}".format(len(input_vars), len(dp))
feed = dict(zip(input_vars, dp))
return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session.
run_input.session = sess
......
......@@ -96,7 +96,7 @@ class PredictorWorkerThread(threading.Thread):
while True:
batched, futures = self.fetch_batch()
outputs = self.func(batched)
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
#print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize()
# debug, for speed testing
#if self.xxx is None:
#self.xxx = outputs = self.func([batched])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment