Commit 70b70736 authored by Yuxin Wu's avatar Yuxin Wu

[MaskCNN] move fpn-rpn and fpn-proposals to models

parent 8d2febb7
...@@ -107,6 +107,7 @@ _C.RPN.BATCH_PER_IM = 256 # total (across FPN levels) number of anchors that ar ...@@ -107,6 +107,7 @@ _C.RPN.BATCH_PER_IM = 256 # total (across FPN levels) number of anchors that ar
_C.RPN.MIN_SIZE = 0 _C.RPN.MIN_SIZE = 0
_C.RPN.PROPOSAL_NMS_THRESH = 0.7 _C.RPN.PROPOSAL_NMS_THRESH = 0.7
_C.RPN.CROWD_OVERLAP_THRES = 0.7 # boxes overlapping crowd will be ignored. _C.RPN.CROWD_OVERLAP_THRES = 0.7 # boxes overlapping crowd will be ignored.
_C.RPN.HEAD_DIM = 1024 # used in C4 only
# RPN proposal selection ------------------------------- # RPN proposal selection -------------------------------
# for C4 # for C4
......
...@@ -569,3 +569,75 @@ def multilevel_roi_align(features, rcnn_boxes, resolution): ...@@ -569,3 +569,75 @@ def multilevel_roi_align(features, rcnn_boxes, resolution):
level_id_invert_perm = tf.invert_permutation(level_id_perm) level_id_invert_perm = tf.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm) all_rois = tf.gather(all_rois, level_id_invert_perm)
return all_rois return all_rois
def multilevel_rpn_losses(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits):
"""
Args:
multilevel_anchors: #lvl RPNAnchors
multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns:
label_loss, box_loss
"""
num_lvl = len(cfg.FPN.ANCHOR_STRIDES)
assert len(multilevel_anchors) == num_lvl
assert len(multilevel_label_logits) == num_lvl
assert len(multilevel_box_logits) == num_lvl
losses = []
for lvl in range(num_lvl):
with tf.name_scope('RPNLoss_Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl]
label_loss, box_loss = rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(),
multilevel_label_logits[lvl], multilevel_box_logits[lvl])
losses.extend([label_loss, box_loss])
with tf.name_scope('rpn_losses'):
total_label_loss = tf.add_n(losses[::2], name='label_loss')
total_box_loss = tf.add_n(losses[1::2], name='box_loss')
add_moving_summary(total_label_loss, total_box_loss)
return total_label_loss, total_box_loss
def generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits,
image_shape2d, pre_nms_topk, post_nms_topk):
"""
Args:
multilevel_anchors: #lvl RPNAnchors
multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns:
boxes: kx4 float
scores: k logits
"""
num_lvl = len(cfg.FPN.ANCHOR_STRIDES)
assert len(multilevel_anchors) == num_lvl
assert len(multilevel_label_logits) == num_lvl
assert len(multilevel_box_logits) == num_lvl
all_boxes = []
all_scores = []
for lvl in range(num_lvl):
with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(multilevel_label_logits[lvl], [-1]),
image_shape2d, pre_nms_topk)
all_boxes.append(proposal_boxes)
all_scores.append(proposal_scores)
proposal_boxes = tf.concat(all_boxes, axis=0) # nx4
proposal_scores = tf.concat(all_scores, axis=0) # n
proposal_topk = tf.minimum(tf.size(proposal_scores), post_nms_topk)
proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
proposal_boxes = tf.gather(proposal_boxes, topk_indices)
return proposal_boxes, proposal_scores
...@@ -35,7 +35,7 @@ from model import ( ...@@ -35,7 +35,7 @@ from model import (
generate_rpn_proposals, sample_fast_rcnn_targets, generate_rpn_proposals, sample_fast_rcnn_targets,
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions, fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_upXconv_head, maskrcnn_loss, maskrcnn_upXconv_head, maskrcnn_loss,
fpn_model, multilevel_roi_align) fpn_model, multilevel_roi_align, multilevel_rpn_losses, generate_fpn_proposals)
from model_box import ( from model_box import (
clip_boxes, decode_bbox_target, encode_bbox_target, clip_boxes, decode_bbox_target, encode_bbox_target,
crop_and_resize, roi_align, RPNAnchors) crop_and_resize, roi_align, RPNAnchors)
...@@ -163,7 +163,7 @@ class ResNetC4Model(DetectionModel): ...@@ -163,7 +163,7 @@ class ResNetC4Model(DetectionModel):
image = self.preprocess(image) # 1CHW image = self.preprocess(image) # 1CHW
featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3]) featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, cfg.RPN.NUM_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)
anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes) anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
anchors = anchors.narrow_to(featuremap) anchors = anchors.narrow_to(featuremap)
...@@ -309,33 +309,15 @@ class ResNetFPNModel(DetectionModel): ...@@ -309,33 +309,15 @@ class ResNetFPNModel(DetectionModel):
self.slice_feature_and_anchors(image_shape2d, p23456, multilevel_anchors) self.slice_feature_and_anchors(image_shape2d, p23456, multilevel_anchors)
# Multi-Level RPN Proposals # Multi-Level RPN Proposals
multilevel_proposals = [] rpn_outputs = [rpn_head('rpn', pi, cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS))
rpn_loss_collection = [] for pi in p23456]
for lvl in range(num_fpn_level): multilevel_label_logits = [k[0] for k in rpn_outputs]
rpn_label_logits, rpn_box_logits = rpn_head( multilevel_box_logits = [k[1] for k in rpn_outputs]
'rpn', p23456[lvl], cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS))
with tf.name_scope('FPN_lvl{}'.format(lvl + 2)): fpn_nms_topk = cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK
anchors = multilevel_anchors[lvl] proposal_boxes, proposal_scores = generate_fpn_proposals(
pred_boxes_decoded = anchors.decode_logits(rpn_box_logits) multilevel_anchors, multilevel_label_logits, multilevel_box_logits,
proposal_boxes, proposal_scores = generate_rpn_proposals( image_shape2d, fpn_nms_topk, fpn_nms_topk)
tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]),
image_shape2d,
cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK)
multilevel_proposals.append((proposal_boxes, proposal_scores))
if is_training:
label_loss, box_loss = rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(),
rpn_label_logits, rpn_box_logits)
rpn_loss_collection.extend([label_loss, box_loss])
# Merge proposals from multi levels, pick top K
proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0) # nx4
proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0) # n
proposal_topk = tf.minimum(tf.size(proposal_scores),
cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK)
proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
proposal_boxes = tf.gather(proposal_boxes, topk_indices)
if is_training: if is_training:
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets( rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
...@@ -351,11 +333,9 @@ class ResNetFPNModel(DetectionModel): ...@@ -351,11 +333,9 @@ class ResNetFPNModel(DetectionModel):
'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS) 'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
if is_training: if is_training:
# rpn loss is already defined above # rpn loss:
with tf.name_scope('rpn_losses'): rpn_label_loss, rpn_box_loss = multilevel_rpn_losses(
rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss') multilevel_anchors, multilevel_label_logits, multilevel_box_logits)
rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)
# fastrcnn loss: # fastrcnn loss:
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt) matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
...@@ -390,7 +370,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -390,7 +370,7 @@ class ResNetFPNModel(DetectionModel):
'(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W', '(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost') l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
total_cost = tf.add_n(rpn_loss_collection + [ total_cost = tf.add_n([rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss, fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost') mrcnn_loss, wd_cost], 'total_cost')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment