Commit d5410902 authored by Yuxin Wu's avatar Yuxin Wu

use input_names in predictconfig

parent cba97f75
...@@ -196,6 +196,7 @@ if __name__ == '__main__': ...@@ -196,6 +196,7 @@ if __name__ == '__main__':
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(args.load), session_init=SaverRestore(args.load),
input_var_names=['state']
output_var_names=['fct/output:0']) output_var_names=['fct/output:0'])
if args.task == 'play': if args.task == 'play':
play_model(cfg) play_model(cfg)
......
...@@ -9,7 +9,7 @@ from tqdm import tqdm ...@@ -9,7 +9,7 @@ from tqdm import tqdm
from six.moves import queue from six.moves import queue
from tensorpack import * from tensorpack import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker from tensorpack.predict import get_predict_func
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import * from tensorpack.utils.stat import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
......
...@@ -104,7 +104,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir): ...@@ -104,7 +104,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
def run_test(model, sess_init, inputs): def run_test(model, sess_init, inputs):
pred_config = PredictConfig( pred_config = PredictConfig(
model=model, model=model,
input_data_mapping=[0], input_var_names=['input'],
session_init=sess_init, session_init=sess_init,
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
output_var_names=['prob:0'] output_var_names=['prob:0']
......
...@@ -59,7 +59,7 @@ def run_test(path, input): ...@@ -59,7 +59,7 @@ def run_test(path, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_data_mapping=[0], input_var_names=['input'],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution output_var_names=['output:0'] # output:0 is the probability distribution
......
...@@ -76,7 +76,7 @@ def run_test(path, input): ...@@ -76,7 +76,7 @@ def run_test(path, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
input_data_mapping=[0], input_var_names=['input'],
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
session_config=get_default_sess_config(0.9), session_config=get_default_sess_config(0.9),
output_var_names=['output:0'] # output:0 is the probability distribution output_var_names=['output:0'] # output:0 is the probability distribution
......
...@@ -30,10 +30,10 @@ get_config_func = imp.load_source('config_script', args.config).get_config ...@@ -30,10 +30,10 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G: with tf.Graph().as_default() as G:
train_config = get_config_func() train_config = get_config_func()
M = train_config.model
config = PredictConfig( config = PredictConfig(
inputs=train_config.inputs, input_var_names=[M.get_input_vars_desc()[0].name], # assume first component is image
input_dataset_mapping=[train_config.inputs[0]], # assume first component is image model=M,
get_model_func=train_config.get_model_func,
session_init=sessinit.SaverRestore(args.model), session_init=sessinit.SaverRestore(args.model),
output_var_names=['output:0'] output_var_names=['output:0']
) )
......
...@@ -42,6 +42,10 @@ class ModelDesc(object): ...@@ -42,6 +42,10 @@ class ModelDesc(object):
g = tf.get_default_graph() g = tf.get_default_graph()
return [g.get_tensor_by_name(name + ":0") for name in input_var_names] return [g.get_tensor_by_name(name + ":0") for name in input_var_names]
def get_input_vars_desc(self):
""" return a list of `InputVar` instance"""
return self._get_input_vars()
@abstractmethod @abstractmethod
def _get_input_vars(self): def _get_input_vars(self):
""":returns: a list of InputVar """ """:returns: a list of InputVar """
......
...@@ -7,6 +7,7 @@ from collections import namedtuple ...@@ -7,6 +7,7 @@ from collections import namedtuple
from six.moves import zip from six.moves import zip
from tensorpack.models import ModelDesc from tensorpack.models import ModelDesc
from ..utils import logger
from ..tfutils import * from ..tfutils import *
import multiprocessing import multiprocessing
...@@ -22,26 +23,8 @@ class PredictConfig(object): ...@@ -22,26 +23,8 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to :param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session. initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data :param input_var_names: a list of input variable names.
to the input tensor, since you may not need all input variables :param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
of the Model to run the graph for prediction (for example
the `label` input is not used if you only need probability distribution).
It should be a list of int with length equal to `len(data_point)`,
where each element in the list defines which input variables each
component in the data point should be fed into.
If not given, defaults to range(len(input_vars))
For example, in image classification task, the testing
dataset only provides datapoints of images (no labels). When
the input variables of the model is: ::
input_vars: [image_var, label_var]
the mapping should then look like: ::
input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
:param model: a `ModelDesc` instance :param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the :param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
...@@ -58,8 +41,21 @@ class PredictConfig(object): ...@@ -58,8 +41,21 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
self.input_data_mapping = kwargs.pop('input_data_mapping', None) self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None)
if input_mapping:
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [raw_vars[k].name for k in input_mapping]
logger.warn('The option `input_data_mapping` was deprecated. \
Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
elif self.input_var_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names') self.output_var_names = kwargs.pop('output_var_names')
assert len(self.input_var_names), self.input_var_names
assert len(self.output_var_names), self.output_var_names
self.return_input = kwargs.pop('return_input', False) self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
...@@ -71,24 +67,19 @@ def get_predict_func(config): ...@@ -71,24 +67,19 @@ def get_predict_func(config):
:returns: A prediction function that takes a list of input values, and return :returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``. a list of output values defined in ``config.output_var_names``.
""" """
output_var_names = config.output_var_names # build graph
# input/output variables
input_vars = config.model.get_input_vars() input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False) config.model._build_graph(input_vars, False)
if config.input_data_mapping is None:
input_map = input_vars
else:
input_map = [input_vars[k] for k in config.input_data_mapping if k >= 0]
# check output_var_names against output_vars input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(output_var_names) output_vars = get_vars_by_names(config.output_var_names)
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
config.session_init.init(sess) config.session_init.init(sess)
def run_input(dp): def run_input(dp):
feed = dict(zip(input_map, dp)) assert len(input_vars) == len(dp), "{} != {}".format(len(input_vars), len(dp))
feed = dict(zip(input_vars, dp))
return sess.run(output_vars, feed_dict=feed) return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session. # XXX hack. so the caller can get access to the session.
run_input.session = sess run_input.session = sess
......
...@@ -96,7 +96,7 @@ class PredictorWorkerThread(threading.Thread): ...@@ -96,7 +96,7 @@ class PredictorWorkerThread(threading.Thread):
while True: while True:
batched, futures = self.fetch_batch() batched, futures = self.fetch_batch()
outputs = self.func(batched) outputs = self.func(batched)
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize() #print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize()
# debug, for speed testing # debug, for speed testing
#if self.xxx is None: #if self.xxx is None:
#self.xxx = outputs = self.func([batched]) #self.xxx = outputs = self.func([batched])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment