Commit c7158c84 authored by Yuxin Wu

[MaskRCNN] bugfix in RPN.CROWD_OVERLAP_THRESH

parent 4b25b6fa
......@@ -160,7 +160,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
cand_inds = np.where(anchor_labels >= 0)[0]
cand_anchors = anchors[cand_inds]
ioas = np_ioa(crowd_boxes, cand_anchors)
overlap_with_crowd = cand_inds[ioas.max(axis=0) > cfg.RPN.CROWD_OVERLAP_THRES]
overlap_with_crowd = cand_inds[ioas.max(axis=0) > cfg.RPN.CROWD_OVERLAP_THRESH]
anchor_labels[overlap_with_crowd] = -1
# Subsample fg labels: ignore some fg if fg is too many
......
......@@ -165,22 +165,19 @@ def multilevel_rpn_losses(
@under_name_scope()
def generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits,
multilevel_box_logits, image_shape2d):
multilevel_pred_boxes, multilevel_label_logits, image_shape2d):
"""
Args:
multilevel_anchors: #lvl RPNAnchors
multilevel_pred_boxes: #lvl HxWxAx4 boxes
multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns:
boxes: kx4 float
scores: k logits
"""
num_lvl = len(cfg.FPN.ANCHOR_STRIDES)
assert len(multilevel_anchors) == num_lvl
assert len(multilevel_pred_boxes) == num_lvl
assert len(multilevel_label_logits) == num_lvl
assert len(multilevel_box_logits) == num_lvl
ctx = get_current_tower_context()
all_boxes = []
......@@ -189,8 +186,7 @@ def generate_fpn_proposals(
fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK
for lvl in range(num_lvl):
with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
pred_boxes_decoded = multilevel_pred_boxes[lvl]
proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(pred_boxes_decoded, [-1, 4]),
......@@ -207,8 +203,7 @@ def generate_fpn_proposals(
else:
for lvl in range(num_lvl):
with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
pred_boxes_decoded = multilevel_pred_boxes[lvl]
all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4]))
all_scores.append(tf.reshape(multilevel_label_logits[lvl], [-1]))
all_boxes = tf.concat(all_boxes, axis=0)
......
......@@ -249,10 +249,11 @@ class ResNetFPNModel(DetectionModel):
for pi in p23456]
multilevel_label_logits = [k[0] for k in rpn_outputs]
multilevel_box_logits = [k[1] for k in rpn_outputs]
multilevel_pred_boxes = [anchor.decode_logits(logits)
for anchor, logits in zip(multilevel_anchors, multilevel_box_logits)]
proposal_boxes, proposal_scores = generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits,
multilevel_box_logits, image_shape2d)
multilevel_pred_boxes, multilevel_label_logits, image_shape2d)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment