Commit db1caa42 authored by Yuxin Wu's avatar Yuxin Wu

fix PredictConfig. use inputs_desc adn tower_func for load-cpm

parent bbb2ecc2
......@@ -44,14 +44,7 @@ def get_gaussian_map():
return gaussian_map.reshape((1, 368, 368, 1))
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, (None, 368, 368, 3), 'input'),
InputDesc(tf.float32, (None, 368, 368, 15), 'label'),
]
def _build_graph(self, inputs):
image, label = inputs
def CPM(image):
image = image / 256.0 - 0.5
gmap = tf.constant(get_gaussian_map())
......@@ -108,7 +101,8 @@ class Model(ModelDesc):
def run_test(model_path, img_file):
param_dict = np.load(model_path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig(
model=Model(),
inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
tower_func=CPM,
session_init=DictRestore(param_dict),
input_names=['input'],
output_names=['resized_map']
......
......@@ -27,7 +27,7 @@ class PredictConfig(object):
):
"""
Args:
model (ModelDescBase): the model to be used to obtain inputs_desc and tower_func.
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors
......@@ -35,8 +35,7 @@ class PredictConfig(object):
session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
inputs of the model.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input (bool): same as in :attr:`PredictorBase.return_input`.
......@@ -70,8 +69,7 @@ class PredictConfig(object):
# inputs & outputs
self.input_names = input_names
if self.input_names is None:
raw_tensors = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_tensors]
self.input_names = [k.name for k in self.inputs_desc]
self.output_names = output_names
assert_type(self.output_names, list)
assert_type(self.input_names, list)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment