Commit 1dcc0e72 authored by Yuxin Wu's avatar Yuxin Wu

require output_var_name in predict.

parent b7766fc1
...@@ -78,11 +78,8 @@ def get_predict_func(config): ...@@ -78,11 +78,8 @@ def get_predict_func(config):
input_map = [input_vars[k] for k in config.input_data_mapping] input_map = [input_vars[k] for k in config.input_data_mapping]
# check output_var_names against output_vars # check output_var_names against output_vars
if output_var_names is not None: output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1]) for n in output_var_names]
for n in output_var_names]
else:
output_vars = []
describe_model() describe_model()
...@@ -94,31 +91,29 @@ def get_predict_func(config): ...@@ -94,31 +91,29 @@ def get_predict_func(config):
"Graph has {} inputs but dataset only gives {} components!".format( "Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp)) len(input_map), len(dp))
feed = dict(zip(input_map, dp)) feed = dict(zip(input_map, dp))
if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed) results = sess.run(output_vars, feed_dict=feed)
return results if len(output_vars) == 1:
return results[0]
else: else:
results = sess.run([cost_var], feed_dict=feed) return results
cost = results[0]
return cost
return run_input return run_input
PredictResult = namedtuple('PredictResult', ['input', 'output']) PredictResult = namedtuple('PredictResult', ['input', 'output'])
# TODO mutligpu predictor
class DatasetPredictor(object): class DatasetPredictor(object):
""" """
Run the predict_config on a given `DataFlow`. Run the predict_config on a given `DataFlow`.
""" """
def __init__(self, predict_config, dataset, batch=0): def __init__(self, predict_config, dataset):
""" """
:param predict_config: a `PredictConfig` instance. :param predict_config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance. :param dataset: a `DataFlow` instance.
:param batch: if batch > zero, will batch the dataset before running.
""" """
assert isinstance(dataset, DataFlow) assert isinstance(dataset, DataFlow)
self.ds = dataset self.ds = dataset
if batch > 0:
self.ds = BatchData(self.ds, batch, remainder=True)
self.predict_func = get_predict_func(predict_config) self.predict_func = get_predict_func(predict_config)
def get_result(self): def get_result(self):
...@@ -133,3 +128,4 @@ class DatasetPredictor(object): ...@@ -133,3 +128,4 @@ class DatasetPredictor(object):
Run over the dataset and return a list of all predictions. Run over the dataset and return a list of all predictions.
""" """
return list(self.get_result()) return list(self.get_result())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment