Commit ed32de25 authored by Yuxin Wu

[MaskRCNN] Joint FPN proposal

parent 7315d2cc
...@@ -115,9 +115,10 @@ _C.RPN.TRAIN_PRE_NMS_TOPK = 12000 ...@@ -115,9 +115,10 @@ _C.RPN.TRAIN_PRE_NMS_TOPK = 12000
_C.RPN.TRAIN_POST_NMS_TOPK = 2000 _C.RPN.TRAIN_POST_NMS_TOPK = 2000
_C.RPN.TEST_PRE_NMS_TOPK = 6000 _C.RPN.TEST_PRE_NMS_TOPK = 6000
_C.RPN.TEST_POST_NMS_TOPK = 1000 # if you encounter OOM in inference, set this to a smaller number _C.RPN.TEST_POST_NMS_TOPK = 1000 # if you encounter OOM in inference, set this to a smaller number
# for FPN, pre/post are (for now) the same # for FPN, #proposals per-level and #proposals after merging are (for now) the same
_C.RPN.TRAIN_FPN_NMS_TOPK = 2000 # if FPN.PROPOSAL_MODE = 'Joint', these options have no effect
_C.RPN.TEST_FPN_NMS_TOPK = 1000 _C.RPN.TRAIN_PER_LEVEL_NMS_TOPK = 2000
_C.RPN.TEST_PER_LEVEL_NMS_TOPK = 1000
# fastrcnn training --------------------- # fastrcnn training ---------------------
_C.FRCNN.BATCH_PER_IM = 512 _C.FRCNN.BATCH_PER_IM = 512
...@@ -127,6 +128,7 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch ...@@ -127,6 +128,7 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch
# FPN ------------------------- # FPN -------------------------
_C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES _C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES
_C.FPN.PROPOSAL_MODE = 'Level' # 'Level', 'Joint'
_C.FPN.NUM_CHANNEL = 256 _C.FPN.NUM_CHANNEL = 256
# conv head and fc head are only used in FPN. # conv head and fc head are only used in FPN.
# For C4 models, the head is C5 # For C4 models, the head is C5
...@@ -162,6 +164,7 @@ def finalize_configs(is_training): ...@@ -162,6 +164,7 @@ def finalize_configs(is_training):
if _C.MODE_FPN: if _C.MODE_FPN:
size_mult = _C.FPN.RESOLUTION_REQUIREMENT * 1. size_mult = _C.FPN.RESOLUTION_REQUIREMENT * 1.
_C.PREPROC.MAX_SIZE = np.ceil(_C.PREPROC.MAX_SIZE / size_mult) * size_mult _C.PREPROC.MAX_SIZE = np.ceil(_C.PREPROC.MAX_SIZE / size_mult) * size_mult
assert _C.FPN.PROPOSAL_MODE in ['Level', 'Joint']
if is_training: if is_training:
os.environ['TF_AUTOTUNE_THRESHOLD'] = '1' os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
......
...@@ -6,6 +6,7 @@ import itertools ...@@ -6,6 +6,7 @@ import itertools
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.tower import get_current_tower_context
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import ( from tensorpack.models import (
Conv2D, layer_register, FixedUnPooling, MaxPooling) Conv2D, layer_register, FixedUnPooling, MaxPooling)
...@@ -145,7 +146,7 @@ def multilevel_rpn_losses( ...@@ -145,7 +146,7 @@ def multilevel_rpn_losses(
label_loss, box_loss = rpn_losses( label_loss, box_loss = rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(), anchors.gt_labels, anchors.encoded_gt_boxes(),
multilevel_label_logits[lvl], multilevel_box_logits[lvl], multilevel_label_logits[lvl], multilevel_box_logits[lvl],
name_scope='Level{}'.format(lvl + 2)) name_scope='level{}'.format(lvl + 2))
losses.extend([label_loss, box_loss]) losses.extend([label_loss, box_loss])
total_label_loss = tf.add_n(losses[::2], name='label_loss') total_label_loss = tf.add_n(losses[::2], name='label_loss')
...@@ -155,8 +156,8 @@ def multilevel_rpn_losses( ...@@ -155,8 +156,8 @@ def multilevel_rpn_losses(
def generate_fpn_proposals( def generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits, multilevel_anchors, multilevel_label_logits,
image_shape2d, pre_nms_topk, post_nms_topk): multilevel_box_logits, image_shape2d):
""" """
Args: Args:
multilevel_anchors: #lvl RPNAnchors multilevel_anchors: #lvl RPNAnchors
...@@ -172,8 +173,11 @@ def generate_fpn_proposals( ...@@ -172,8 +173,11 @@ def generate_fpn_proposals(
assert len(multilevel_label_logits) == num_lvl assert len(multilevel_label_logits) == num_lvl
assert len(multilevel_box_logits) == num_lvl assert len(multilevel_box_logits) == num_lvl
ctx = get_current_tower_context()
all_boxes = [] all_boxes = []
all_scores = [] all_scores = []
if cfg.FPN.PROPOSAL_MODE == 'Level':
fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)): with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl] anchors = multilevel_anchors[lvl]
...@@ -182,13 +186,27 @@ def generate_fpn_proposals( ...@@ -182,13 +186,27 @@ def generate_fpn_proposals(
proposal_boxes, proposal_scores = generate_rpn_proposals( proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(pred_boxes_decoded, [-1, 4]), tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(multilevel_label_logits[lvl], [-1]), tf.reshape(multilevel_label_logits[lvl], [-1]),
image_shape2d, pre_nms_topk) image_shape2d, fpn_nms_topk)
all_boxes.append(proposal_boxes) all_boxes.append(proposal_boxes)
all_scores.append(proposal_scores) all_scores.append(proposal_scores)
proposal_boxes = tf.concat(all_boxes, axis=0) # nx4 proposal_boxes = tf.concat(all_boxes, axis=0) # nx4
proposal_scores = tf.concat(all_scores, axis=0) # n proposal_scores = tf.concat(all_scores, axis=0) # n
proposal_topk = tf.minimum(tf.size(proposal_scores), post_nms_topk) proposal_topk = tf.minimum(tf.size(proposal_scores), fpn_nms_topk)
proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False) proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
proposal_boxes = tf.gather(proposal_boxes, topk_indices) proposal_boxes = tf.gather(proposal_boxes, topk_indices)
else:
for lvl in range(num_lvl):
with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4]))
all_scores.append(tf.reshape(multilevel_label_logits[lvl], [-1]))
all_boxes = tf.concat(all_boxes, axis=0)
all_scores = tf.concat(all_scores, axis=0)
proposal_boxes, proposal_scores = generate_rpn_proposals(
all_boxes, all_scores, image_shape2d,
cfg.RPN.TRAIN_PRE_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
cfg.RPN.TRAIN_POST_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_POST_NMS_TOPK)
return proposal_boxes, proposal_scores return proposal_boxes, proposal_scores
...@@ -318,10 +318,9 @@ class ResNetFPNModel(DetectionModel): ...@@ -318,10 +318,9 @@ class ResNetFPNModel(DetectionModel):
multilevel_label_logits = [k[0] for k in rpn_outputs] multilevel_label_logits = [k[0] for k in rpn_outputs]
multilevel_box_logits = [k[1] for k in rpn_outputs] multilevel_box_logits = [k[1] for k in rpn_outputs]
fpn_nms_topk = cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK
proposal_boxes, proposal_scores = generate_fpn_proposals( proposal_boxes, proposal_scores = generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits, multilevel_anchors, multilevel_label_logits,
image_shape2d, fpn_nms_topk, fpn_nms_topk) multilevel_box_logits, image_shape2d)
if is_training: if is_training:
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets( rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
...@@ -356,7 +355,8 @@ class ResNetFPNModel(DetectionModel): ...@@ -356,7 +355,8 @@ class ResNetFPNModel(DetectionModel):
# maskrcnn loss # maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample) fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
roi_feature_maskrcnn = multilevel_roi_align( roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], fg_sampled_boxes, 14) p23456[:4], fg_sampled_boxes, 14,
name_scope='multilevel_roi_align_mask')
mask_logits = maskrcnn_upXconv_head( mask_logits = maskrcnn_upXconv_head(
'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY, 4) # #fg x #cat x 28 x 28 'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY, 4) # #fg x #cat x 28 x 28
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment