Commit c5df0501 authored by ppwwyyxx's avatar ppwwyyxx

input_dataset_mapping for prediction

parent 9bb0b8f6
......@@ -135,6 +135,7 @@ def run_test(path):
pred_config = PredictConfig(
inputs=input_vars,
input_dataset_mapping=[input_vars[0]],
get_model_func=get_model,
session_init=ParamRestore(param_dict),
output_var_names=['output:0'] # output:0 is the probability distribution
......@@ -146,7 +147,7 @@ def run_test(path):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3))
outputs = predict_func([im, (1,)])[0]
outputs = predict_func([im])[0]
prob = outputs[0]
print prob.shape
print prob.argsort()[-10:][::-1]
......
......@@ -22,8 +22,22 @@ class PredictConfig(object):
session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session.
inputs: a list of input variables. must match the dataset later
used for prediction.
inputs: input variables of the graph.
input_dataset_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example
the `label` input is not used if you only need probability
distribution). It should be a list with size=len(one_data_point),
where each element is a tensor which each component of the
data point should be fed into.
If not given, defaults to `inputs`.
For example, with image classification task, the testing
dataset only provides datapoints of images (no labels). The
arguments should look like:
inputs: [image_var, label_var]
input_dataset_mapping: [image_var]
If this argument is not set, the inputs and the data points won't be aligned.
get_model_func: a function taking `inputs` and `is_training` and
return a tuple of output list as well as the cost to minimize
output_var_names: a list of names of the output variable to predict, the
......@@ -38,7 +52,7 @@ class PredictConfig(object):
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init')
self.inputs = kwargs.pop('inputs')
[assert_type(i, tf.Tensor) for i in self.inputs]
self.input_dataset_mapping = kwargs.pop('input_dataset_mapping', None)
self.get_model_func = kwargs.pop('get_model_func')
self.output_var_names = kwargs.pop('output_var_names', None)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......@@ -59,6 +73,9 @@ def get_predict_func(config):
# input/output variables
input_vars = config.inputs
output_vars, cost_var = config.get_model_func(input_vars, is_training=False)
input_map = config.input_dataset_mapping
if input_map is None:
input_map = input_vars
# check output_var_names against output_vars
if output_var_names is not None:
......@@ -70,8 +87,10 @@ def get_predict_func(config):
config.session_init.init(sess)
def run_input(dp):
# TODO if input and dp not aligned?
feed = dict(zip(input_vars, dp))
assert len(input_map) == len(dp), \
"Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(zip(input_map, dp))
if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed)
return results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment