Commit cc3de54d authored by Yuxin Wu

[MaskRCNN] small clean-ups

parent 1ae8b540
...@@ -180,7 +180,6 @@ _C.MRCNN.HEAD_DIM = 256 ...@@ -180,7 +180,6 @@ _C.MRCNN.HEAD_DIM = 256
# Cascade-RCNN, only available in FPN mode # Cascade-RCNN, only available in FPN mode
_C.FPN.CASCADE = False _C.FPN.CASCADE = False
_C.CASCADE.NUM_STAGES = 3
_C.CASCADE.IOUS = [0.5, 0.6, 0.7] _C.CASCADE.IOUS = [0.5, 0.6, 0.7]
_C.CASCADE.BBOX_REG_WEIGHTS = [[10., 10., 5., 5.], [20., 20., 10., 10.], [30., 30., 15., 15.]] _C.CASCADE.BBOX_REG_WEIGHTS = [[10., 10., 5., 5.], [20., 20., 10., 10.], [30., 30., 15., 15.]]
...@@ -220,11 +219,9 @@ def finalize_configs(is_training): ...@@ -220,11 +219,9 @@ def finalize_configs(is_training):
assert _C.FPN.NORM in ['None', 'GN'] assert _C.FPN.NORM in ['None', 'GN']
if _C.FPN.CASCADE: if _C.FPN.CASCADE:
num_cascade = _C.CASCADE.NUM_STAGES
# the first threshold is the proposal sampling threshold # the first threshold is the proposal sampling threshold
assert len(_C.CASCADE.IOUS) == num_cascade
assert _C.CASCADE.IOUS[0] == _C.FRCNN.FG_THRESH assert _C.CASCADE.IOUS[0] == _C.FRCNN.FG_THRESH
assert len(_C.CASCADE.BBOX_REG_WEIGHTS) == num_cascade assert len(_C.CASCADE.BBOX_REG_WEIGHTS) == len(_C.CASCADE.IOUS)
if is_training: if is_training:
train_scales = _C.PREPROC.TRAIN_SHORT_EDGE_SIZE train_scales = _C.PREPROC.TRAIN_SHORT_EDGE_SIZE
......
...@@ -10,18 +10,21 @@ from config import config as cfg ...@@ -10,18 +10,21 @@ from config import config as cfg
class CascadeRCNNHead(object): class CascadeRCNNHead(object):
def __init__(self, proposals, def __init__(self, proposals,
roi_func, fastrcnn_head_func, image_shape2d, num_classes): roi_func, fastrcnn_head_func, gt_targets, image_shape2d, num_classes):
""" """
Args: Args:
proposals: BoxProposals proposals: BoxProposals
roi_func (boxes -> features): a function to crop features with rois roi_func (boxes -> features): a function to crop features with rois
fastrcnn_head_func (features -> features): the fastrcnn head to apply on the cropped features fastrcnn_head_func (features -> features): the fastrcnn head to apply on the cropped features
gt_targets (gt_boxes, gt_labels):
""" """
for k, v in locals().items(): for k, v in locals().items():
if k != 'self': if k != 'self':
setattr(self, k, v) setattr(self, k, v)
self.gt_boxes, self.gt_labels = gt_targets
del self.gt_targets
self.num_cascade_stages = cfg.CASCADE.NUM_STAGES self.num_cascade_stages = len(cfg.CASCADE.IOUS)
self.is_training = get_current_tower_context().is_training self.is_training = get_current_tower_context().is_training
if self.is_training: if self.is_training:
...@@ -29,8 +32,6 @@ class CascadeRCNNHead(object): ...@@ -29,8 +32,6 @@ class CascadeRCNNHead(object):
def scale_gradient(x): def scale_gradient(x):
return x, lambda dy: dy * (1.0 / self.num_cascade_stages) return x, lambda dy: dy * (1.0 / self.num_cascade_stages)
self.scale_gradient = scale_gradient self.scale_gradient = scale_gradient
self.gt_boxes = proposals.gt_boxes
self.gt_labels = proposals.gt_labels
else: else:
self.scale_gradient = tf.identity self.scale_gradient = tf.identity
...@@ -66,7 +67,7 @@ class CascadeRCNNHead(object): ...@@ -66,7 +67,7 @@ class CascadeRCNNHead(object):
head_feature = self.fastrcnn_head_func('head', pooled_feature) head_feature = self.fastrcnn_head_func('head', pooled_feature)
label_logits, box_logits = fastrcnn_outputs( label_logits, box_logits = fastrcnn_outputs(
'outputs', head_feature, self.num_classes, class_agnostic_regression=True) 'outputs', head_feature, self.num_classes, class_agnostic_regression=True)
head = FastRCNNHead(proposals, box_logits, label_logits, reg_weights) head = FastRCNNHead(proposals, box_logits, label_logits, self.gt_boxes, reg_weights)
refined_boxes = head.decoded_output_boxes_class_agnostic() refined_boxes = head.decoded_output_boxes_class_agnostic()
refined_boxes = clip_boxes(refined_boxes, self.image_shape2d) refined_boxes = clip_boxes(refined_boxes, self.image_shape2d)
...@@ -88,8 +89,7 @@ class CascadeRCNNHead(object): ...@@ -88,8 +89,7 @@ class CascadeRCNNHead(object):
fg_mask = max_iou_per_box >= iou_threshold fg_mask = max_iou_per_box >= iou_threshold
fg_inds_wrt_gt = tf.boolean_mask(best_iou_ind, fg_mask) fg_inds_wrt_gt = tf.boolean_mask(best_iou_ind, fg_mask)
labels_per_box = tf.stop_gradient(labels_per_box * tf.to_int64(fg_mask)) labels_per_box = tf.stop_gradient(labels_per_box * tf.to_int64(fg_mask))
return BoxProposals( return BoxProposals(boxes, labels_per_box, fg_inds_wrt_gt)
boxes, labels_per_box, fg_inds_wrt_gt, self.gt_boxes, self.gt_labels)
else: else:
return BoxProposals(boxes) return BoxProposals(boxes)
......
...@@ -99,8 +99,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -99,8 +99,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
return BoxProposals( return BoxProposals(
tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'), tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'),
tf.stop_gradient(ret_labels, name='sampled_labels'), tf.stop_gradient(ret_labels, name='sampled_labels'),
tf.stop_gradient(fg_inds_wrt_gt), tf.stop_gradient(fg_inds_wrt_gt))
gt_boxes, gt_labels)
@layer_register(log_shape=True) @layer_register(log_shape=True)
...@@ -302,16 +301,12 @@ class BoxProposals(object): ...@@ -302,16 +301,12 @@ class BoxProposals(object):
""" """
A structure to manage box proposals and their relations with ground truth. A structure to manage box proposals and their relations with ground truth.
""" """
def __init__(self, boxes, def __init__(self, boxes, labels=None, fg_inds_wrt_gt=None):
labels=None, fg_inds_wrt_gt=None,
gt_boxes=None, gt_labels=None):
""" """
Args: Args:
boxes: Nx4 boxes: Nx4
labels: N, each in [0, #class), the true label for each input box labels: N, each in [0, #class), the true label for each input box
fg_inds_wrt_gt: #fg, each in [0, M) fg_inds_wrt_gt: #fg, each in [0, M)
gt_boxes: Mx4
gt_labels: M
The last four arguments could be None when not training. The last four arguments could be None when not training.
""" """
...@@ -334,22 +329,18 @@ class BoxProposals(object): ...@@ -334,22 +329,18 @@ class BoxProposals(object):
""" Returns: #fg""" """ Returns: #fg"""
return tf.gather(self.labels, self.fg_inds(), name='fg_labels') return tf.gather(self.labels, self.fg_inds(), name='fg_labels')
@memoized_method
def matched_gt_boxes(self):
""" Returns: #fg x 4"""
return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt)
class FastRCNNHead(object): class FastRCNNHead(object):
""" """
A class to process & decode inputs/outputs of a fastrcnn classification+regression head. A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
""" """
def __init__(self, proposals, box_logits, label_logits, bbox_regression_weights): def __init__(self, proposals, box_logits, label_logits, gt_boxes, bbox_regression_weights):
""" """
Args: Args:
proposals: BoxProposals proposals: BoxProposals
box_logits: Nx#classx4 or Nx1x4, the output of the head box_logits: Nx#classx4 or Nx1x4, the output of the head
label_logits: Nx#class, the output of the head label_logits: Nx#class, the output of the head
gt_boxes: Mx4
bbox_regression_weights: a 4 element tensor bbox_regression_weights: a 4 element tensor
""" """
for k, v in locals().items(): for k, v in locals().items():
...@@ -365,7 +356,7 @@ class FastRCNNHead(object): ...@@ -365,7 +356,7 @@ class FastRCNNHead(object):
@memoized_method @memoized_method
def losses(self): def losses(self):
encoded_fg_gt_boxes = encode_bbox_target( encoded_fg_gt_boxes = encode_bbox_target(
self.proposals.matched_gt_boxes(), tf.gather(self.gt_boxes, self.proposals.fg_inds_wrt_gt),
self.proposals.fg_boxes()) * self.bbox_regression_weights self.proposals.fg_boxes()) * self.bbox_regression_weights
return fastrcnn_losses( return fastrcnn_losses(
self.proposals.labels, self.label_logits, self.proposals.labels, self.label_logits,
......
...@@ -96,12 +96,10 @@ class DetectionModel(ModelDesc): ...@@ -96,12 +96,10 @@ class DetectionModel(ModelDesc):
image = self.preprocess(inputs['image']) # 1CHW image = self.preprocess(inputs['image']) # 1CHW
features = self.backbone(image) features = self.backbone(image)
proposals, rpn_losses = self.rpn(image, features, inputs) # inputs? anchor_inputs = {k: v for k, v in inputs.items() if k.startswith('anchor_')}
proposals, rpn_losses = self.rpn(image, features, anchor_inputs) # inputs?
targets = [inputs['gt_boxes'], inputs['gt_labels']]
if 'gt_masks' in inputs:
targets.append(inputs['gt_masks'])
targets = [inputs[k] for k in ['gt_boxes', 'gt_labels', 'gt_masks'] if k in inputs]
head_losses = self.roi_heads(image, features, proposals, targets) head_losses = self.roi_heads(image, features, proposals, targets)
if self.training: if self.training:
...@@ -299,13 +297,14 @@ class ResNetFPNModel(DetectionModel): ...@@ -299,13 +297,14 @@ class ResNetFPNModel(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs( fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS) 'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits, fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)) gt_boxes, tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
else: else:
def roi_func(boxes): def roi_func(boxes):
return multilevel_roi_align(features[:4], boxes, 7) return multilevel_roi_align(features[:4], boxes, 7)
fastrcnn_head = CascadeRCNNHead( fastrcnn_head = CascadeRCNNHead(
proposals, roi_func, fastrcnn_head_func, image_shape2d, cfg.DATA.NUM_CLASS) proposals, roi_func, fastrcnn_head_func,
(gt_boxes, gt_labels), image_shape2d, cfg.DATA.NUM_CLASS)
if self.training: if self.training:
all_losses = fastrcnn_head.losses() all_losses = fastrcnn_head.losses()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment