Commit 4f1efe74 authored by Yuxin Wu

[MaskRCNN] let fasterrcnn predict #class boxes instead of #category

parent 16581e74
...@@ -252,15 +252,15 @@ def fastrcnn_outputs(feature, num_classes): ...@@ -252,15 +252,15 @@ def fastrcnn_outputs(feature, num_classes):
num_classes(int): num_category + 1 num_classes(int): num_category + 1
Returns: Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4) cls_logits (Nxnum_class), reg_logits (Nx num_class x 4)
""" """
classification = FullyConnected( classification = FullyConnected(
'class', feature, num_classes, 'class', feature, num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=0.01)) kernel_initializer=tf.random_normal_initializer(stddev=0.01))
box_regression = FullyConnected( box_regression = FullyConnected(
'box', feature, (num_classes - 1) * 4, 'box', feature, num_classes * 4,
kernel_initializer=tf.random_normal_initializer(stddev=0.001)) kernel_initializer=tf.random_normal_initializer(stddev=0.001))
box_regression = tf.reshape(box_regression, (-1, num_classes - 1, 4)) box_regression = tf.reshape(box_regression, (-1, num_classes, 4))
return classification, box_regression return classification, box_regression
...@@ -314,7 +314,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -314,7 +314,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
labels: n, labels: n,
label_logits: nxC label_logits: nxC
fg_boxes: nfgx4, encoded fg_boxes: nfgx4, encoded
fg_box_logits: nfgx(C-1)x4 fg_box_logits: nfgxCx4
""" """
label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=label_logits) labels=labels, logits=label_logits)
...@@ -325,7 +325,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -325,7 +325,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
num_fg = tf.size(fg_inds) num_fg = tf.size(fg_inds)
indices = tf.stack( indices = tf.stack(
[tf.range(num_fg), [tf.range(num_fg),
tf.to_int32(fg_labels) - 1], axis=1) # #fgx2 tf.to_int32(fg_labels)], axis=1) # #fgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices) fg_box_logits = tf.gather_nd(fg_box_logits, indices)
with tf.name_scope('label_metrics'), tf.device('/cpu:0'): with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
......
...@@ -131,6 +131,8 @@ class DetectionModel(ModelDesc): ...@@ -131,6 +131,8 @@ class DetectionModel(ModelDesc):
boxes (mx4): boxes (mx4):
labels (m): each >= 1 labels (m): each >= 1
""" """
rcnn_box_logits = rcnn_box_logits[:, 1:, :]
rcnn_box_logits.set_shape([None, config.NUM_CLASS - 1, None])
label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class
anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4 anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target( decoded_boxes = decode_bbox_target(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment