Commit 443cb84d authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] add Cascade-RCNN

parent f7dfb86a
......@@ -6,7 +6,7 @@ This is a minimal implementation that simply contains these files:
+ common.py: common data preparation utilities
+ basemodel.py: implement backbones
+ model_box.py: implement box-related symbolic functions
+ model_{fpn,rpn,mrcnn,frcnn}.py: implement FPN,RPN,Mask-/Fast-RCNN models.
+ model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast-/Mask-/Cascade-RCNN models.
+ train.py: main training script
+ utils/: third-party helper functions
+ eval.py: evaluation utilities
......@@ -47,7 +47,7 @@ Model:
Speed:
1. The training will start very slowly due to convolution warmup, and takes about 10k
steps to reach a maximum speed.
steps to reach a maximum speed.
You can disable warmup by `export TF_CUDNN_USE_AUTOTUNE=0`, which makes the
training faster at the beginning, but perhaps not in the end.
......@@ -56,7 +56,7 @@ Speed:
1. This implementation is about 10% slower than detectron,
probably due to the lack of specialized ops (e.g. AffineChannel, ROIAlign) in TensorFlow.
It's certainly faster than other TF implementations.
1. The code should have around 70% GPU utilization on V100s, and 85%~90% scaling
efficiency from 1 V100 to 8 V100s.
......
# Faster-RCNN / Mask-RCNN on COCO
This example provides a minimal (<2k lines) and faithful implementation of the following papers:
This example provides a minimal (2k lines) and faithful implementation of the following papers:
+ [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)
+ [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+ [Mask R-CNN](https://arxiv.org/abs/1703.06870)
+ [Cascade R-CNN: Delving into High Quality Object Detection](https://arxiv.org/abs/1712.00726)
with the support of:
+ Multi-GPU / distributed training
......
......@@ -163,6 +163,12 @@ _C.FPN.MRCNN_HEAD_FUNC = 'maskrcnn_up4conv_head' # choices: maskrcnn_up4conv_{
# Mask-RCNN
_C.MRCNN.HEAD_DIM = 256
# Cascade-RCNN, only available in FPN mode
# When FPN.CASCADE is True, finalize_configs() asserts that CASCADE.IOUS and
# CASCADE.BBOX_REG_WEIGHTS each contain exactly CASCADE.NUM_STAGES entries.
_C.FPN.CASCADE = False
# Number of cascade stages.
_C.CASCADE.NUM_STAGES = 3
# One IoU threshold per stage. finalize_configs() asserts that the first entry
# equals FRCNN.FG_THRESH (it is the proposal sampling threshold).
_C.CASCADE.IOUS = [0.5, 0.6, 0.7]
# One list of four box-regression weights per stage.
_C.CASCADE.BBOX_REG_WEIGHTS = [[10., 10., 5., 5.], [20., 20., 10., 10.], [30., 30., 15., 15.]]
# testing -----------------------
_C.TEST.FRCNN_NMS_THRESH = 0.5
......@@ -198,6 +204,13 @@ def finalize_configs(is_training):
assert _C.FPN.MRCNN_HEAD_FUNC.endswith('_head')
assert _C.FPN.NORM in ['None', 'GN']
if _C.FPN.CASCADE:
num_cascade = _C.CASCADE.NUM_STAGES
# the first threshold is the proposal sampling threshold
assert len(_C.CASCADE.IOUS) == num_cascade
assert _C.CASCADE.IOUS[0] == _C.FRCNN.FG_THRESH
assert len(_C.CASCADE.BBOX_REG_WEIGHTS) == num_cascade
if is_training:
os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
assert _C.TRAINER in ['horovod', 'replicated'], _C.TRAINER
......
import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
from utils.box_ops import pairwise_iou
from model_box import clip_boxes
from model_frcnn import FastRCNNHead, BoxProposals, fastrcnn_outputs
from config import config as cfg
class CascadeRCNNHead(object):
    """
    A cascade of three Fast R-CNN heads.

    Each stage crops features for its input boxes, runs a head with
    class-agnostic box regression, and passes its refined (clipped,
    gradient-stopped) boxes to the next stage. During training the refined
    boxes are re-matched against the ground truth with that stage's IoU
    threshold before being used as the next stage's proposals.
    """

    def __init__(self, proposals,
                 roi_func, fastrcnn_head_func, image_shape2d, num_classes):
        """
        Args:
            proposals: BoxProposals
            roi_func (boxes -> features): a function to crop features with rois
            fastrcnn_head_func (features -> features): the fastrcnn head to apply
                on the cropped features. It is called as
                ``fastrcnn_head_func('head', features)``.
            image_shape2d: forwarded to ``clip_boxes`` to clip the refined boxes
                of every stage.
            num_classes: number of classes, forwarded to ``fastrcnn_outputs``.
        """
        # Keep every constructor argument as an attribute of the same name.
        self.proposals = proposals
        self.roi_func = roi_func
        self.fastrcnn_head_func = fastrcnn_head_func
        self.image_shape2d = image_shape2d
        self.num_classes = num_classes
        self.num_cascade_stages = cfg.CASCADE.NUM_STAGES

        self.is_training = get_current_tower_context().is_training
        if not self.is_training:
            # No gradient flows at inference time, so no scaling is needed.
            self.scale_gradient = tf.identity
        else:
            @tf.custom_gradient
            def scale_gradient(x):
                # Identity in the forward pass; the backward pass scales the
                # incoming gradient by 1 / (number of stages).
                def grad(dy):
                    return dy * (1.0 / self.num_cascade_stages)
                return x, grad

            self.scale_gradient = scale_gradient
            # Ground truth is needed to re-match the refined boxes of each stage.
            self.gt_boxes = proposals.gt_boxes
            self.gt_labels = proposals.gt_labels

        thresholds = cfg.CASCADE.IOUS
        # It's unclear how to do >3 stages, so it does not make sense to implement them
        assert self.num_cascade_stages == 3, "Only 3-stage cascade was implemented!"

        scope_names = ('cascade_rcnn_stage1', 'cascade_rcnn_stage2', 'cascade_rcnn_stage3')
        stage_heads = []
        stage_boxes = []
        stage_input = self.proposals
        for stage, scope_name in enumerate(scope_names):
            with tf.variable_scope(scope_name):
                if stage > 0:
                    # Later stages consume the previous stage's refined boxes,
                    # re-matched with this stage's IoU threshold.
                    stage_input = self.match_box_with_gt(stage_boxes[-1], thresholds[stage])
                head, refined = self.run_head(stage_input, stage)
            stage_heads.append(head)
            stage_boxes.append(refined)

        self._cascade_boxes = stage_boxes
        self._heads = stage_heads

    def run_head(self, proposals, stage):
        """
        Build one cascade stage on top of the given proposals.

        Args:
            proposals: BoxProposals
            stage: 0, 1, 2

        Returns:
            FastRCNNHead
            Nx4, updated boxes
        """
        stage_weights = tf.constant(cfg.CASCADE.BBOX_REG_WEIGHTS[stage], dtype=tf.float32)
        roi_features = self.scale_gradient(self.roi_func(proposals.boxes))  # N,C,S,S
        features = self.fastrcnn_head_func('head', roi_features)
        label_logits, box_logits = fastrcnn_outputs(
            'outputs', features, self.num_classes, class_agnostic_regression=True)
        head = FastRCNNHead(proposals, box_logits, label_logits, stage_weights)
        decoded = head.decoded_output_boxes_class_agnostic()
        clipped = clip_boxes(decoded, self.image_shape2d)
        # The refined boxes feed the next stage as fixed inputs: no gradient through them.
        return head, tf.stop_gradient(clipped, name='output_boxes')

    def match_box_with_gt(self, boxes, iou_threshold):
        """
        Wrap boxes into proposals; in training, also label them against the ground truth.

        Args:
            boxes: Nx4
            iou_threshold: a box is foreground when its best IoU with any
                ground-truth box is >= this value.

        Returns:
            BoxProposals
        """
        if not self.is_training:
            return BoxProposals(boxes)

        with tf.name_scope('match_box_with_gt_{}'.format(iou_threshold)):
            overlaps = pairwise_iou(boxes, self.gt_boxes)  # NxM
            best_overlap = tf.reduce_max(overlaps, axis=1)  # N
            best_gt_index = tf.argmax(overlaps, axis=1)  # N
            matched_labels = tf.gather(self.gt_labels, best_gt_index)
            is_foreground = best_overlap >= iou_threshold
            fg_inds_wrt_gt = tf.boolean_mask(best_gt_index, is_foreground)
            # Boxes below the threshold get label 0.
            matched_labels = tf.stop_gradient(matched_labels * tf.to_int64(is_foreground))
            return BoxProposals(
                boxes, matched_labels, fg_inds_wrt_gt, self.gt_boxes, self.gt_labels)

    def losses(self):
        """
        Returns:
            a flat list with the losses of every stage's head, in stage order.
        """
        collected = []
        for stage_number, head in enumerate(self._heads, 1):
            with tf.name_scope('cascade_loss_stage{}'.format(stage_number)):
                collected.extend(head.losses())
        return collected

    def decoded_output_boxes(self):
        """
        Returns:
            Nx#classx4
        """
        final_boxes = self._cascade_boxes[-1]
        # The regression is class-agnostic, so the same box is repeated for every class.
        per_class = tf.expand_dims(final_boxes, 1)
        return tf.tile(per_class, [1, self.num_classes, 1])

    def output_scores(self, name=None):
        """
        Average the scores produced by all stages.

        Returns:
            Nx#class
        """
        per_stage = []
        for stage_number, head in enumerate(self._heads, 1):
            per_stage.append(head.output_scores('cascade_scores_stage{}'.format(stage_number)))
        total = tf.add_n(per_stage)
        return tf.multiply(total, (1.0 / self.num_cascade_stages), name=name)
......@@ -41,6 +41,7 @@ from model_rpn import rpn_head, rpn_losses, generate_rpn_proposals
from model_fpn import (
fpn_model, multilevel_roi_align,
multilevel_rpn_losses, generate_fpn_proposals)
from model_cascade import CascadeRCNNHead
from model_box import (
clip_boxes, crop_and_resize, roi_align, RPNAnchors)
......@@ -258,14 +259,21 @@ class ResNetFPNModel(DetectionModel):
else:
proposals = BoxProposals(proposal_boxes)
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], proposals.boxes, 7)
fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC)
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if not cfg.FPN.CASCADE:
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], proposals.boxes, 7)
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
else:
def roi_func(boxes):
return multilevel_roi_align(p23456[:4], boxes, 7)
fastrcnn_head = CascadeRCNNHead(
proposals, roi_func, fastrcnn_head_func, image_shape2d, cfg.DATA.NUM_CLASS)
if is_training:
all_losses = []
......
......@@ -34,8 +34,10 @@ if __name__ == '__main__':
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
globvarname = [k.name for k in tf.global_variables()]
if len(set(var_to_dump)) != len(var_to_dump):
print("TRAINABLE and MODEL variables have duplication!")
var_to_dump = list(set(var_to_dump))
globvarname = set([k.name for k in tf.global_variables()])
var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
for name in var_to_dump:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment