Commit db1caa42 authored by Yuxin Wu's avatar Yuxin Wu

fix PredictConfig. use inputs_desc adn tower_func for load-cpm

parent bbb2ecc2
...@@ -44,71 +44,65 @@ def get_gaussian_map(): ...@@ -44,71 +44,65 @@ def get_gaussian_map():
return gaussian_map.reshape((1, 368, 368, 1)) return gaussian_map.reshape((1, 368, 368, 1))
class Model(ModelDesc): def CPM(image):
def _get_inputs(self): image = image / 256.0 - 0.5
return [InputDesc(tf.float32, (None, 368, 368, 3), 'input'),
InputDesc(tf.float32, (None, 368, 368, 15), 'label'), gmap = tf.constant(get_gaussian_map())
] gmap = tf.pad(gmap, [[0, 0], [0, 1], [0, 1], [0, 0]])
pool_center = AvgPooling('mappool', gmap, 9, stride=8, padding='VALID')
def _build_graph(self, inputs): with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu,
image, label = inputs W_init=tf.random_normal_initializer(stddev=0.01)):
image = image / 256.0 - 0.5 shared = (LinearWrap(image)
.Conv2D('conv1_1', 64)
gmap = tf.constant(get_gaussian_map()) .Conv2D('conv1_2', 64)
gmap = tf.pad(gmap, [[0, 0], [0, 1], [0, 1], [0, 0]]) .MaxPooling('pool1', 2)
pool_center = AvgPooling('mappool', gmap, 9, stride=8, padding='VALID') # 184
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, .Conv2D('conv2_1', 128)
W_init=tf.random_normal_initializer(stddev=0.01)): .Conv2D('conv2_2', 128)
shared = (LinearWrap(image) .MaxPooling('pool2', 2)
.Conv2D('conv1_1', 64) # 92
.Conv2D('conv1_2', 64) .Conv2D('conv3_1', 256)
.MaxPooling('pool1', 2) .Conv2D('conv3_2', 256)
# 184 .Conv2D('conv3_3', 256)
.Conv2D('conv2_1', 128) .Conv2D('conv3_4', 256)
.Conv2D('conv2_2', 128) .MaxPooling('pool3', 2)
.MaxPooling('pool2', 2) # 46
# 92 .Conv2D('conv4_1', 512)
.Conv2D('conv3_1', 256) .Conv2D('conv4_2', 512)
.Conv2D('conv3_2', 256) .Conv2D('conv4_3_CPM', 256)
.Conv2D('conv3_3', 256) .Conv2D('conv4_4_CPM', 256)
.Conv2D('conv3_4', 256) .Conv2D('conv4_5_CPM', 256)
.MaxPooling('pool3', 2) .Conv2D('conv4_6_CPM', 256)
# 46 .Conv2D('conv4_7_CPM', 128)())
.Conv2D('conv4_1', 512)
.Conv2D('conv4_2', 512) def add_stage(stage, l):
.Conv2D('conv4_3_CPM', 256) l = tf.concat([l, shared, pool_center], 3,
.Conv2D('conv4_4_CPM', 256) name='concat_stage{}'.format(stage))
.Conv2D('conv4_5_CPM', 256) for i in range(1, 6):
.Conv2D('conv4_6_CPM', 256) l = Conv2D('Mconv{}_stage{}'.format(i, stage), l, 128)
.Conv2D('conv4_7_CPM', 128)()) l = Conv2D('Mconv6_stage{}'.format(stage), l, 128, kernel_shape=1)
l = Conv2D('Mconv7_stage{}'.format(stage),
def add_stage(stage, l): l, 15, kernel_shape=1, nl=tf.identity)
l = tf.concat([l, shared, pool_center], 3, return l
name='concat_stage{}'.format(stage))
for i in range(1, 6): with argscope(Conv2D, kernel_shape=7, nl=tf.nn.relu):
l = Conv2D('Mconv{}_stage{}'.format(i, stage), l, 128) out1 = (LinearWrap(shared)
l = Conv2D('Mconv6_stage{}'.format(stage), l, 128, kernel_shape=1) .Conv2D('conv5_1_CPM', 512, kernel_shape=1)
l = Conv2D('Mconv7_stage{}'.format(stage), .Conv2D('conv5_2_CPM', 15, kernel_shape=1, nl=tf.identity)())
l, 15, kernel_shape=1, nl=tf.identity) out2 = add_stage(2, out1)
return l out3 = add_stage(3, out2)
out4 = add_stage(4, out3)
with argscope(Conv2D, kernel_shape=7, nl=tf.nn.relu): out5 = add_stage(5, out4)
out1 = (LinearWrap(shared) out6 = add_stage(6, out4)
.Conv2D('conv5_1_CPM', 512, kernel_shape=1) resized_map = tf.image.resize_bilinear(out6,
.Conv2D('conv5_2_CPM', 15, kernel_shape=1, nl=tf.identity)()) [368, 368], name='resized_map')
out2 = add_stage(2, out1)
out3 = add_stage(3, out2)
out4 = add_stage(4, out3)
out5 = add_stage(5, out4)
out6 = add_stage(6, out4)
resized_map = tf.image.resize_bilinear(out6,
[368, 368], name='resized_map')
def run_test(model_path, img_file): def run_test(model_path, img_file):
param_dict = np.load(model_path, encoding='latin1').item() param_dict = np.load(model_path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
model=Model(), inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
tower_func=CPM,
session_init=DictRestore(param_dict), session_init=DictRestore(param_dict),
input_names=['input'], input_names=['input'],
output_names=['resized_map'] output_names=['resized_map']
......
...@@ -27,7 +27,7 @@ class PredictConfig(object): ...@@ -27,7 +27,7 @@ class PredictConfig(object):
): ):
""" """
Args: Args:
model (ModelDescBase): the model to be used to obtain inputs_desc and tower_func. model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors tower_func: a callable which takes input tensors
...@@ -35,8 +35,7 @@ class PredictConfig(object): ...@@ -35,8 +35,7 @@ class PredictConfig(object):
session. Defaults to :class:`tf.train.ChiefSessionCreator()`. session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing. Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all input_names (list): a list of input tensor names. Defaults to match inputs_desc.
inputs of the model.
output_names (list): a list of names of the output tensors to predict, the output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph. tensors can be any computable tensor in the graph.
return_input (bool): same as in :attr:`PredictorBase.return_input`. return_input (bool): same as in :attr:`PredictorBase.return_input`.
...@@ -70,8 +69,7 @@ class PredictConfig(object): ...@@ -70,8 +69,7 @@ class PredictConfig(object):
# inputs & outputs # inputs & outputs
self.input_names = input_names self.input_names = input_names
if self.input_names is None: if self.input_names is None:
raw_tensors = self.model.get_inputs_desc() self.input_names = [k.name for k in self.inputs_desc]
self.input_names = [k.name for k in raw_tensors]
self.output_names = output_names self.output_names = output_names
assert_type(self.output_names, list) assert_type(self.output_names, list)
assert_type(self.input_names, list) assert_type(self.input_names, list)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment