Commit cc3de54d authored by Yuxin Wu

[MaskRCNN] small clean-ups

parent 1ae8b540
......@@ -180,7 +180,6 @@ _C.MRCNN.HEAD_DIM = 256
# Cascade-RCNN, only available in FPN mode
_C.FPN.CASCADE = False
_C.CASCADE.NUM_STAGES = 3
_C.CASCADE.IOUS = [0.5, 0.6, 0.7]
_C.CASCADE.BBOX_REG_WEIGHTS = [[10., 10., 5., 5.], [20., 20., 10., 10.], [30., 30., 15., 15.]]
......@@ -220,11 +219,9 @@ def finalize_configs(is_training):
assert _C.FPN.NORM in ['None', 'GN']
if _C.FPN.CASCADE:
num_cascade = _C.CASCADE.NUM_STAGES
# the first threshold is the proposal sampling threshold
assert len(_C.CASCADE.IOUS) == num_cascade
assert _C.CASCADE.IOUS[0] == _C.FRCNN.FG_THRESH
assert len(_C.CASCADE.BBOX_REG_WEIGHTS) == num_cascade
assert len(_C.CASCADE.BBOX_REG_WEIGHTS) == len(_C.CASCADE.IOUS)
if is_training:
train_scales = _C.PREPROC.TRAIN_SHORT_EDGE_SIZE
......
......@@ -10,18 +10,21 @@ from config import config as cfg
class CascadeRCNNHead(object):
def __init__(self, proposals,
roi_func, fastrcnn_head_func, image_shape2d, num_classes):
roi_func, fastrcnn_head_func, gt_targets, image_shape2d, num_classes):
"""
Args:
proposals: BoxProposals
roi_func (boxes -> features): a function to crop features with rois
fastrcnn_head_func (features -> features): the fastrcnn head to apply on the cropped features
gt_targets (gt_boxes, gt_labels): pair of ground-truth boxes (Mx4) and ground-truth labels (M)
"""
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.gt_boxes, self.gt_labels = gt_targets
del self.gt_targets
self.num_cascade_stages = cfg.CASCADE.NUM_STAGES
self.num_cascade_stages = len(cfg.CASCADE.IOUS)
self.is_training = get_current_tower_context().is_training
if self.is_training:
......@@ -29,8 +32,6 @@ class CascadeRCNNHead(object):
def scale_gradient(x):
return x, lambda dy: dy * (1.0 / self.num_cascade_stages)
self.scale_gradient = scale_gradient
self.gt_boxes = proposals.gt_boxes
self.gt_labels = proposals.gt_labels
else:
self.scale_gradient = tf.identity
......@@ -66,7 +67,7 @@ class CascadeRCNNHead(object):
head_feature = self.fastrcnn_head_func('head', pooled_feature)
label_logits, box_logits = fastrcnn_outputs(
'outputs', head_feature, self.num_classes, class_agnostic_regression=True)
head = FastRCNNHead(proposals, box_logits, label_logits, reg_weights)
head = FastRCNNHead(proposals, box_logits, label_logits, self.gt_boxes, reg_weights)
refined_boxes = head.decoded_output_boxes_class_agnostic()
refined_boxes = clip_boxes(refined_boxes, self.image_shape2d)
......@@ -88,8 +89,7 @@ class CascadeRCNNHead(object):
fg_mask = max_iou_per_box >= iou_threshold
fg_inds_wrt_gt = tf.boolean_mask(best_iou_ind, fg_mask)
labels_per_box = tf.stop_gradient(labels_per_box * tf.to_int64(fg_mask))
return BoxProposals(
boxes, labels_per_box, fg_inds_wrt_gt, self.gt_boxes, self.gt_labels)
return BoxProposals(boxes, labels_per_box, fg_inds_wrt_gt)
else:
return BoxProposals(boxes)
......
......@@ -99,8 +99,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
return BoxProposals(
tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'),
tf.stop_gradient(ret_labels, name='sampled_labels'),
tf.stop_gradient(fg_inds_wrt_gt),
gt_boxes, gt_labels)
tf.stop_gradient(fg_inds_wrt_gt))
@layer_register(log_shape=True)
......@@ -302,16 +301,12 @@ class BoxProposals(object):
"""
A structure to manage box proposals and their relations with ground truth.
"""
def __init__(self, boxes,
labels=None, fg_inds_wrt_gt=None,
gt_boxes=None, gt_labels=None):
def __init__(self, boxes, labels=None, fg_inds_wrt_gt=None):
"""
Args:
boxes: Nx4
labels: N, each in [0, #class), the true label for each input box
fg_inds_wrt_gt: #fg, each in [0, M)
gt_boxes: Mx4
gt_labels: M
The last two arguments could be None when not training.
"""
......@@ -334,22 +329,18 @@ class BoxProposals(object):
""" Returns: #fg"""
return tf.gather(self.labels, self.fg_inds(), name='fg_labels')
@memoized_method
def matched_gt_boxes(self):
""" Returns: #fg x 4"""
return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt)
class FastRCNNHead(object):
"""
A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
"""
def __init__(self, proposals, box_logits, label_logits, bbox_regression_weights):
def __init__(self, proposals, box_logits, label_logits, gt_boxes, bbox_regression_weights):
"""
Args:
proposals: BoxProposals
box_logits: Nx#classx4 or Nx1x4, the output of the head
label_logits: Nx#class, the output of the head
gt_boxes: Mx4
bbox_regression_weights: a 4 element tensor
"""
for k, v in locals().items():
......@@ -365,7 +356,7 @@ class FastRCNNHead(object):
@memoized_method
def losses(self):
encoded_fg_gt_boxes = encode_bbox_target(
self.proposals.matched_gt_boxes(),
tf.gather(self.gt_boxes, self.proposals.fg_inds_wrt_gt),
self.proposals.fg_boxes()) * self.bbox_regression_weights
return fastrcnn_losses(
self.proposals.labels, self.label_logits,
......
......@@ -96,12 +96,10 @@ class DetectionModel(ModelDesc):
image = self.preprocess(inputs['image']) # 1CHW
features = self.backbone(image)
proposals, rpn_losses = self.rpn(image, features, inputs) # inputs?
targets = [inputs['gt_boxes'], inputs['gt_labels']]
if 'gt_masks' in inputs:
targets.append(inputs['gt_masks'])
anchor_inputs = {k: v for k, v in inputs.items() if k.startswith('anchor_')}
proposals, rpn_losses = self.rpn(image, features, anchor_inputs) # inputs?
targets = [inputs[k] for k in ['gt_boxes', 'gt_labels', 'gt_masks'] if k in inputs]
head_losses = self.roi_heads(image, features, proposals, targets)
if self.training:
......@@ -299,13 +297,14 @@ class ResNetFPNModel(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
gt_boxes, tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
else:
def roi_func(boxes):
return multilevel_roi_align(features[:4], boxes, 7)
fastrcnn_head = CascadeRCNNHead(
proposals, roi_func, fastrcnn_head_func, image_shape2d, cfg.DATA.NUM_CLASS)
proposals, roi_func, fastrcnn_head_func,
(gt_boxes, gt_labels), image_shape2d, cfg.DATA.NUM_CLASS)
if self.training:
all_losses = fastrcnn_head.losses()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment