Commit 6041a1a4 authored by Yuxin Wu

[MaskRCNN] BoxProposals struct to manage proposals; rename probs->scores

parent cf97218c
......@@ -59,7 +59,7 @@ To predict on an image (and show output in a window):
./train.py --predict input.jpg --load /path/to/model --config SAME-AS-TRAINING
```
To Evaluate the performance of a model on COCO:
To evaluate the performance of a model on COCO:
```
./train.py --evaluate output.json --load /path/to/COCO-R50C4-MaskRCNN-Standard.npz \
--config SAME-AS-TRAINING
......
......@@ -50,8 +50,9 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
gt_labels: m, int32
Returns:
A BoxProposals instance.
sampled_boxes: tx4 floatbox, the rois
sampled_labels: t int64 labels, in [0, #class-1]. Positive means foreground.
sampled_labels: t int64 labels, in [0, #class). Positive means foreground.
fg_inds_wrt_gt: #fg indices, each in range [0, m-1].
It contains the matching GT of each foreground roi.
"""
......@@ -94,9 +95,11 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
[tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0)
# stop the gradient -- they are meant to be training targets
return tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'), \
tf.stop_gradient(ret_labels, name='sampled_labels'), \
tf.stop_gradient(fg_inds_wrt_gt)
return BoxProposals(
tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'),
tf.stop_gradient(ret_labels, name='sampled_labels'),
tf.stop_gradient(fg_inds_wrt_gt),
gt_boxes, gt_labels)
@layer_register(log_shape=True)
......@@ -168,23 +171,24 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
@under_name_scope()
def fastrcnn_predictions(boxes, probs):
def fastrcnn_predictions(boxes, scores):
"""
Generate final results from predictions of all proposals.
Args:
boxes: nx#classx4 floatbox in float32
probs: nx#class
scores: nx#class
Returns:
indices: Kx2. Each is (box_id, class_id)
probs: K floats
boxes: Kx4
scores: K
labels: K
"""
assert boxes.shape[1] == cfg.DATA.NUM_CLASS
assert probs.shape[1] == cfg.DATA.NUM_CLASS
assert scores.shape[1] == cfg.DATA.NUM_CLASS
boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4
boxes.set_shape([None, cfg.DATA.NUM_CATEGORY, None])
probs = tf.transpose(probs[:, 1:], [1, 0]) # #catxn
scores = tf.transpose(scores[:, 1:], [1, 0]) # #catxn
def f(X):
"""
......@@ -213,20 +217,24 @@ def fastrcnn_predictions(boxes, probs):
default_value=False)
return mask
masks = tf.map_fn(f, (probs, boxes), dtype=tf.bool,
masks = tf.map_fn(f, (scores, boxes), dtype=tf.bool,
parallel_iterations=10) # #cat x N
selected_indices = tf.where(masks) # #selection x 2, each is (cat_id, box_id)
probs = tf.boolean_mask(probs, masks)
scores = tf.boolean_mask(scores, masks)
# filter again by sorting scores
topk_probs, topk_indices = tf.nn.top_k(
probs,
tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(probs)),
topk_scores, topk_indices = tf.nn.top_k(
scores,
tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(scores)),
sorted=False)
filtered_selection = tf.gather(selected_indices, topk_indices)
cat_ids, box_ids = tf.unstack(filtered_selection, axis=1)
final_ids = tf.stack([box_ids, cat_ids + 1], axis=1, name='final_ids') # Kx2, each is (box_id, class_id)
return final_ids, topk_probs
final_scores = tf.identity(topk_scores, name='scores')
final_labels = tf.add(cat_ids, 1, name='labels')
final_ids = tf.stack([cat_ids, box_ids], axis=1, name='all_ids')
final_boxes = tf.gather_nd(boxes, final_ids, name='boxes')
return final_boxes, final_scores, final_labels
"""
......@@ -284,63 +292,84 @@ def fastrcnn_4conv1fc_gn_head(*args, **kwargs):
return fastrcnn_Xconv1fc_head(*args, num_convs=4, norm='GN', **kwargs)
class FastRCNNHead(object):
class BoxProposals(object):
"""
A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
A structure to manage box proposals and their relation with ground truth.
"""
def __init__(self, input_boxes, box_logits, label_logits, bbox_regression_weights,
labels=None, matched_gt_boxes_per_fg=None):
def __init__(self, boxes,
labels=None, fg_inds_wrt_gt=None,
gt_boxes=None, gt_labels=None):
"""
Args:
input_boxes: Nx4, inputs to the head
box_logits: Nx#classx4 or Nx1x4, the output of the head
label_logits: Nx#class, the output of the head
bbox_regression_weights: a 4 element tensor
boxes: Nx4
labels: N, each in [0, #class), the true label for each input box
matched_gt_boxes_per_fg: #fgx4, the matching gt boxes for each fg input box
fg_inds_wrt_gt: #fg, each in [0, M)
gt_boxes: Mx4
gt_labels: M
The last two arguments could be None when not training.
The last four arguments could be None when not training.
"""
for k, v in locals().items():
if k != 'self':
if k != 'self' and v is not None:
setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
@memoized
def fg_inds_in_inputs(self):
def fg_inds(self):
""" Returns: #fg indices in [0, N-1] """
assert self.labels is not None
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds_in_inputs')
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds')
@memoized
def fg_input_boxes(self):
""" Returns: #fgx4 """
return tf.gather(self.input_boxes, self.fg_inds_in_inputs(), name='fg_input_boxes')
def fg_boxes(self):
""" Returns: #fg x4"""
return tf.gather(self.boxes, self.fg_inds(), name='fg_boxes')
@memoized
def fg_box_logits(self):
""" Returns: #fg x ? x 4 """
return tf.gather(self.box_logits, self.fg_inds_in_inputs(), name='fg_box_logits')
def fg_labels(self):
""" Returns: #fg"""
return tf.gather(self.labels, self.fg_inds(), name='fg_labels')
@memoized
def fg_labels(self):
""" Returns: #fg """
return tf.gather(self.labels, self.fg_inds_in_inputs(), name='fg_labels')
def matched_gt_boxes(self):
""" Returns: #fg x 4"""
return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt)
class FastRCNNHead(object):
"""
A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
"""
def __init__(self, proposals, box_logits, label_logits, bbox_regression_weights):
"""
Args:
proposals: BoxProposals
box_logits: Nx#classx4 or Nx1x4, the output of the head
label_logits: Nx#class, the output of the head
bbox_regression_weights: a 4 element tensor
"""
for k, v in locals().items():
if k != 'self' and v is not None:
setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
@memoized
def fg_box_logits(self):
""" Returns: #fg x ? x 4 """
return tf.gather(self.box_logits, self.proposals.fg_inds(), name='fg_box_logits')
@memoized
def losses(self):
encoded_fg_gt_boxes = encode_bbox_target(
self.matched_gt_boxes_per_fg,
self.fg_input_boxes()) * self.bbox_regression_weights
self.proposals.matched_gt_boxes(),
self.proposals.fg_boxes()) * self.bbox_regression_weights
return fastrcnn_losses(
self.labels, self.label_logits,
self.proposals.labels, self.label_logits,
encoded_fg_gt_boxes, self.fg_box_logits()
)
@memoized
def decoded_output_boxes(self):
""" Returns: N x #class x 4 """
anchors = tf.tile(tf.expand_dims(self.input_boxes, 1),
anchors = tf.tile(tf.expand_dims(self.proposals.boxes, 1),
[1, cfg.DATA.NUM_CLASS, 1]) # N x #class x 4
decoded_boxes = decode_bbox_target(
self.box_logits / self.bbox_regression_weights,
......@@ -351,8 +380,7 @@ class FastRCNNHead(object):
@memoized
def decoded_output_boxes_for_true_label(self):
""" Returns: Nx4 decoded boxes """
assert self.labels is not None
return self._decoded_output_boxes_for_label(self.labels)
return self._decoded_output_boxes_for_label(self.proposals.labels)
@memoized
def decoded_output_boxes_for_predicted_label(self):
......@@ -363,13 +391,13 @@ class FastRCNNHead(object):
def decoded_output_boxes_for_label(self, labels):
assert not self._bbox_class_agnostic
indices = tf.stack([
tf.range(tf.size(self.labels, out_type=tf.int64)),
tf.range(tf.size(labels, out_type=tf.int64)),
labels
])
needed_logits = tf.gather_nd(self.box_logits, indices)
decoded = decode_bbox_target(
needed_logits / self.bbox_regression_weights,
self.input_boxes
self.proposals.boxes
)
return decoded
......@@ -379,7 +407,7 @@ class FastRCNNHead(object):
box_logits = tf.reshape(self.box_logits, [-1, 4])
decoded = decode_bbox_target(
box_logits / self.bbox_regression_weights,
self.input_boxes
self.proposals.boxes
)
return decoded
......
......@@ -34,8 +34,8 @@ from basemodel import (
import model_frcnn
import model_mrcnn
from model_frcnn import (
sample_fast_rcnn_targets,
fastrcnn_outputs, fastrcnn_predictions, FastRCNNHead)
sample_fast_rcnn_targets, fastrcnn_outputs,
fastrcnn_predictions, BoxProposals, FastRCNNHead)
from model_mrcnn import maskrcnn_upXconv_head, maskrcnn_loss
from model_rpn import rpn_head, rpn_losses, generate_rpn_proposals
from model_fpn import (
......@@ -72,27 +72,6 @@ class DetectionModel(ModelDesc):
opt = optimizer.AccumGradOptimizer(opt, 8 // cfg.TRAIN.NUM_GPUS)
return opt
def fastrcnn_inference(self, image_shape2d, fastrcnn_head):
"""
Args:
image_shape2d: h, w
fastrcnn_head (FastRCNNHead):
Returns:
boxes (mx4):
labels (m): each >= 1
"""
decoded_boxes = fastrcnn_head.decoded_output_boxes()
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
label_probs = fastrcnn_head.output_scores(name='fastrcnn_all_probs')
# indices: Nx2. Each index into (#box, #class)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
final_probs = tf.identity(final_probs, 'final_probs')
final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
final_labels = tf.gather(pred_indices, 1, axis=1, name='final_labels')
return final_boxes, final_labels
def get_inference_tensor_names(self):
"""
Returns two lists of tensor names to be used to create an inference callable.
......@@ -101,9 +80,9 @@ class DetectionModel(ModelDesc):
[str]: input names
[str]: output names
"""
out = ['final_boxes', 'final_probs', 'final_labels']
out = ['output/boxes', 'output/scores', 'output/labels']
if cfg.MODE_MASK:
out.append('final_masks')
out.append('output/masks')
return ['image'], out
......@@ -144,16 +123,13 @@ class ResNetC4Model(DetectionModel):
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training:
# sample proposal boxes in training
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt, name='gt_boxes_per_fg_proposal')
proposals = sample_fast_rcnn_targets(proposal_boxes, gt_boxes, gt_labels)
else:
# The boxes to be used to crop RoIs.
# Use all proposal boxes in inference
rcnn_boxes = proposal_boxes
rcnn_labels, matched_gt_boxes = None, None
proposals = BoxProposals(proposal_boxes)
boxes_on_featuremap = rcnn_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
boxes_on_featuremap = proposals.boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1]) # nxcx7x7
......@@ -161,9 +137,8 @@ class ResNetC4Model(DetectionModel):
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training:
# rpn loss
......@@ -176,17 +151,17 @@ class ResNetC4Model(DetectionModel):
if cfg.MODE_MASK:
# maskrcnn loss
# In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, fastrcnn_head.fg_inds_in_inputs())
fg_feature = tf.gather(feature_fastrcnn, proposals.fg_inds())
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0) # #fg x #cat x 14x14
target_masks_for_fg = crop_and_resize(
tf.expand_dims(inputs['gt_masks'], 1),
fastrcnn_head.fg_input_boxes(),
fg_inds_wrt_gt, 14,
proposals.fg_boxes(),
proposals.fg_inds_wrt_gt, 14,
pad_border=False) # nfg x 1x14x14
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, fastrcnn_head.fg_labels(), target_masks_for_fg)
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)
else:
mrcnn_loss = 0.0
......@@ -201,7 +176,11 @@ class ResNetC4Model(DetectionModel):
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
final_boxes, final_labels = self.fastrcnn_inference(image_shape2d, fastrcnn_head)
decoded_boxes = fastrcnn_head.decoded_output_boxes()
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
label_scores = fastrcnn_head.output_scores(name='fastrcnn_all_scores')
final_boxes, final_scores, final_labels = fastrcnn_predictions(
decoded_boxes, label_scores, name_scope='output')
if cfg.MODE_MASK:
roi_resized = roi_align(featuremap, final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE), 14)
......@@ -210,7 +189,7 @@ class ResNetC4Model(DetectionModel):
'maskrcnn', feature_maskrcnn, cfg.DATA.NUM_CATEGORY, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
tf.sigmoid(final_mask_logits, name='final_masks')
tf.sigmoid(final_mask_logits, name='output/masks')
class ResNetFPNModel(DetectionModel):
......@@ -279,23 +258,18 @@ class ResNetFPNModel(DetectionModel):
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training:
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
proposals = sample_fast_rcnn_targets(proposal_boxes, gt_boxes, gt_labels)
else:
# The boxes to be used to crop RoIs.
rcnn_boxes = proposal_boxes
rcnn_labels, matched_gt_boxes = None, None
proposals = BoxProposals(proposal_boxes)
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], proposals.boxes, 7)
fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC)
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training:
# rpn loss:
......@@ -307,7 +281,7 @@ class ResNetFPNModel(DetectionModel):
if cfg.MODE_MASK:
# maskrcnn loss
roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], fastrcnn_head.fg_input_boxes(), 14,
p23456[:4], proposals.fg_boxes(), 14,
name_scope='multilevel_roi_align_mask')
maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC)
mask_logits = maskrcnn_head_func(
......@@ -315,11 +289,11 @@ class ResNetFPNModel(DetectionModel):
target_masks_for_fg = crop_and_resize(
tf.expand_dims(inputs['gt_masks'], 1),
fastrcnn_head.fg_input_boxes(),
fg_inds_wrt_gt, 28,
proposals.fg_boxes(),
proposals.fg_inds_wrt_gt, 28,
pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, fastrcnn_head.fg_labels(), target_masks_for_fg)
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)
else:
mrcnn_loss = 0.0
......@@ -333,7 +307,11 @@ class ResNetFPNModel(DetectionModel):
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
final_boxes, final_labels = self.fastrcnn_inference(image_shape2d, fastrcnn_head)
decoded_boxes = fastrcnn_head.decoded_output_boxes()
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
label_scores = fastrcnn_head.output_scores(name='fastrcnn_all_scores')
final_boxes, final_scores, final_labels = fastrcnn_predictions(
decoded_boxes, label_scores, name_scope='output')
if cfg.MODE_MASK:
# Cascade inference needs roi transform with refined boxes.
roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)
......@@ -342,7 +320,7 @@ class ResNetFPNModel(DetectionModel):
'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28
tf.sigmoid(final_mask_logits, name='final_masks')
tf.sigmoid(final_mask_logits, name='output/masks')
def visualize(model, model_path, nr_visualize=100, output_dir='output'):
......@@ -358,11 +336,11 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[
'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
'generate_{}_proposals/probs'.format('fpn' if cfg.MODE_FPN else 'rpn'),
'fastrcnn_all_probs',
'final_boxes',
'final_probs',
'final_labels',
'generate_{}_proposals/scores'.format('fpn' if cfg.MODE_FPN else 'rpn'),
'fastrcnn_all_scores',
'output/boxes',
'output/scores',
'output/labels',
]))
if os.path.isdir(output_dir):
......@@ -376,18 +354,18 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
else:
gt_boxes, gt_labels = dp[-2:]
rpn_boxes, rpn_scores, all_probs, \
final_boxes, final_probs, final_labels = pred(img, gt_boxes, gt_labels)
rpn_boxes, rpn_scores, all_scores, \
final_boxes, final_scores, final_labels = pred(img, gt_boxes, gt_labels)
# draw groundtruth boxes
gt_viz = draw_annotation(img, gt_boxes, gt_labels)
# draw best proposals for each groundtruth, to show recall
proposal_viz, good_proposals_ind = draw_proposal_recall(img, rpn_boxes, rpn_scores, gt_boxes)
# draw the scores for the above proposals
score_viz = draw_predictions(img, rpn_boxes[good_proposals_ind], all_probs[good_proposals_ind])
score_viz = draw_predictions(img, rpn_boxes[good_proposals_ind], all_scores[good_proposals_ind])
results = [DetectionResult(*args) for args in
zip(final_boxes, final_probs, final_labels,
zip(final_boxes, final_scores, final_labels,
[None] * len(final_labels))]
final_viz = draw_final_outputs(img, results)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment