Commit c7158c84 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] bugfix in RPN.CROWD_OVERLAP_THRESH

parent 4b25b6fa
...@@ -160,7 +160,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes): ...@@ -160,7 +160,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
cand_inds = np.where(anchor_labels >= 0)[0] cand_inds = np.where(anchor_labels >= 0)[0]
cand_anchors = anchors[cand_inds] cand_anchors = anchors[cand_inds]
ioas = np_ioa(crowd_boxes, cand_anchors) ioas = np_ioa(crowd_boxes, cand_anchors)
overlap_with_crowd = cand_inds[ioas.max(axis=0) > cfg.RPN.CROWD_OVERLAP_THRES] overlap_with_crowd = cand_inds[ioas.max(axis=0) > cfg.RPN.CROWD_OVERLAP_THRESH]
anchor_labels[overlap_with_crowd] = -1 anchor_labels[overlap_with_crowd] = -1
# Subsample fg labels: ignore some fg if fg is too many # Subsample fg labels: ignore some fg if fg is too many
......
...@@ -165,22 +165,19 @@ def multilevel_rpn_losses( ...@@ -165,22 +165,19 @@ def multilevel_rpn_losses(
@under_name_scope() @under_name_scope()
def generate_fpn_proposals( def generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_pred_boxes, multilevel_label_logits, image_shape2d):
multilevel_box_logits, image_shape2d):
""" """
Args: Args:
multilevel_anchors: #lvl RPNAnchors multilevel_pred_boxes: #lvl HxWxAx4 boxes
multilevel_label_logits: #lvl tensors of shape HxWxA multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns: Returns:
boxes: kx4 float boxes: kx4 float
scores: k logits scores: k logits
""" """
num_lvl = len(cfg.FPN.ANCHOR_STRIDES) num_lvl = len(cfg.FPN.ANCHOR_STRIDES)
assert len(multilevel_anchors) == num_lvl assert len(multilevel_pred_boxes) == num_lvl
assert len(multilevel_label_logits) == num_lvl assert len(multilevel_label_logits) == num_lvl
assert len(multilevel_box_logits) == num_lvl
ctx = get_current_tower_context() ctx = get_current_tower_context()
all_boxes = [] all_boxes = []
...@@ -189,8 +186,7 @@ def generate_fpn_proposals( ...@@ -189,8 +186,7 @@ def generate_fpn_proposals(
fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('Lvl{}'.format(lvl + 2)): with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl] pred_boxes_decoded = multilevel_pred_boxes[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
proposal_boxes, proposal_scores = generate_rpn_proposals( proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(pred_boxes_decoded, [-1, 4]), tf.reshape(pred_boxes_decoded, [-1, 4]),
...@@ -207,8 +203,7 @@ def generate_fpn_proposals( ...@@ -207,8 +203,7 @@ def generate_fpn_proposals(
else: else:
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('Lvl{}'.format(lvl + 2)): with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl] pred_boxes_decoded = multilevel_pred_boxes[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4])) all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4]))
all_scores.append(tf.reshape(multilevel_label_logits[lvl], [-1])) all_scores.append(tf.reshape(multilevel_label_logits[lvl], [-1]))
all_boxes = tf.concat(all_boxes, axis=0) all_boxes = tf.concat(all_boxes, axis=0)
......
...@@ -249,10 +249,11 @@ class ResNetFPNModel(DetectionModel): ...@@ -249,10 +249,11 @@ class ResNetFPNModel(DetectionModel):
for pi in p23456] for pi in p23456]
multilevel_label_logits = [k[0] for k in rpn_outputs] multilevel_label_logits = [k[0] for k in rpn_outputs]
multilevel_box_logits = [k[1] for k in rpn_outputs] multilevel_box_logits = [k[1] for k in rpn_outputs]
multilevel_pred_boxes = [anchor.decode_logits(logits)
for anchor, logits in zip(multilevel_anchors, multilevel_box_logits)]
proposal_boxes, proposal_scores = generate_fpn_proposals( proposal_boxes, proposal_scores = generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_pred_boxes, multilevel_label_logits, image_shape2d)
multilevel_box_logits, image_shape2d)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels'] gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training: if is_training:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment