Commit c5df0501 authored by ppwwyyxx's avatar ppwwyyxx

input_dataset_mapping for prediction

parent 9bb0b8f6
...@@ -135,6 +135,7 @@ def run_test(path): ...@@ -135,6 +135,7 @@ def run_test(path):
pred_config = PredictConfig( pred_config = PredictConfig(
inputs=input_vars, inputs=input_vars,
input_dataset_mapping=[input_vars[0]],
get_model_func=get_model, get_model_func=get_model,
session_init=ParamRestore(param_dict), session_init=ParamRestore(param_dict),
output_var_names=['output:0'] # output:0 is the probability distribution output_var_names=['output:0'] # output:0 is the probability distribution
...@@ -146,7 +147,7 @@ def run_test(path): ...@@ -146,7 +147,7 @@ def run_test(path):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227)) im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3)) im = np.reshape(im, (1, 227, 227, 3))
outputs = predict_func([im, (1,)])[0] outputs = predict_func([im])[0]
prob = outputs[0] prob = outputs[0]
print prob.shape print prob.shape
print prob.argsort()[-10:][::-1] print prob.argsort()[-10:][::-1]
......
...@@ -22,8 +22,22 @@ class PredictConfig(object): ...@@ -22,8 +22,22 @@ class PredictConfig(object):
session. default to a session running 1 GPU. session. default to a session running 1 GPU.
session_init: a tensorpack.utils.sessinit.SessionInit instance to session_init: a tensorpack.utils.sessinit.SessionInit instance to
initialize variables of a session. initialize variables of a session.
inputs: a list of input variables. must match the dataset later inputs: input variables of the graph.
used for prediction. input_dataset_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example
the `label` input is not used if you only need probability
distribution). It should be a list with size=len(one_data_point),
where each element is a tensor which each component of the
data point should be fed into.
If not given, defaults to `inputs`.
For example, with image classification task, the testing
dataset only provides datapoints of images (no labels). The
arguments should look like:
inputs: [image_var, label_var]
input_dataset_mapping: [image_var]
If this argument is not set, the inputs and the data points won't be aligned.
get_model_func: a function taking `inputs` and `is_training` and get_model_func: a function taking `inputs` and `is_training` and
return a tuple of output list as well as the cost to minimize return a tuple of output list as well as the cost to minimize
output_var_names: a list of names of the output variable to predict, the output_var_names: a list of names of the output variable to predict, the
...@@ -38,7 +52,7 @@ class PredictConfig(object): ...@@ -38,7 +52,7 @@ class PredictConfig(object):
assert_type(self.session_config, tf.ConfigProto) assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init') self.session_init = kwargs.pop('session_init')
self.inputs = kwargs.pop('inputs') self.inputs = kwargs.pop('inputs')
[assert_type(i, tf.Tensor) for i in self.inputs] self.input_dataset_mapping = kwargs.pop('input_dataset_mapping', None)
self.get_model_func = kwargs.pop('get_model_func') self.get_model_func = kwargs.pop('get_model_func')
self.output_var_names = kwargs.pop('output_var_names', None) self.output_var_names = kwargs.pop('output_var_names', None)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
...@@ -59,6 +73,9 @@ def get_predict_func(config): ...@@ -59,6 +73,9 @@ def get_predict_func(config):
# input/output variables # input/output variables
input_vars = config.inputs input_vars = config.inputs
output_vars, cost_var = config.get_model_func(input_vars, is_training=False) output_vars, cost_var = config.get_model_func(input_vars, is_training=False)
input_map = config.input_dataset_mapping
if input_map is None:
input_map = input_vars
# check output_var_names against output_vars # check output_var_names against output_vars
if output_var_names is not None: if output_var_names is not None:
...@@ -70,8 +87,10 @@ def get_predict_func(config): ...@@ -70,8 +87,10 @@ def get_predict_func(config):
config.session_init.init(sess) config.session_init.init(sess)
def run_input(dp): def run_input(dp):
# TODO if input and dp not aligned? assert len(input_map) == len(dp), \
feed = dict(zip(input_vars, dp)) "Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(zip(input_map, dp))
if output_var_names is not None: if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed) results = sess.run(output_vars, feed_dict=feed)
return results return results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment