Commit 99ba595d authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] use FastRCNNHead to manage head inputs&outputs

parent 02eb02e2
......@@ -8,9 +8,11 @@ from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, FullyConnected, layer_register)
from tensorpack.utils.argtools import memoized
from basemodel import GroupNorm
from utils.box_ops import pairwise_iou
from model_box import encode_bbox_target, decode_bbox_target
from config import config as cfg
......@@ -113,7 +115,7 @@ def fastrcnn_outputs(feature, num_classes):
box_regression = FullyConnected(
'box', feature, num_classes * 4,
kernel_initializer=tf.random_normal_initializer(stddev=0.001))
box_regression = tf.reshape(box_regression, (-1, num_classes, 4))
box_regression = tf.reshape(box_regression, (-1, num_classes, 4), name='output_box')
return classification, box_regression
......@@ -125,6 +127,9 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
label_logits: nxC
fg_boxes: nfgx4, encoded
fg_box_logits: nfgxCx4
Returns:
label_loss, box_loss
"""
label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=label_logits)
......@@ -132,10 +137,9 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds = tf.where(labels > 0)[:, 0]
fg_labels = tf.gather(labels, fg_inds)
num_fg = tf.size(fg_inds)
num_fg = tf.size(fg_inds, out_type=tf.int64)
indices = tf.stack(
[tf.range(num_fg),
tf.to_int32(fg_labels)], axis=1) # #fgx2
[tf.range(num_fg), fg_labels], axis=1) # #fgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices)
with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
......@@ -143,7 +147,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
correct = tf.to_float(tf.equal(prediction, labels)) # boolean/integer gather is unavailable on GPU
accuracy = tf.reduce_mean(correct, name='accuracy')
fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
num_zero = tf.reduce_sum(tf.to_int32(tf.equal(fg_label_pred, 0)), name='num_zero')
num_zero = tf.reduce_sum(tf.to_int64(tf.equal(fg_label_pred, 0)), name='num_zero')
false_negative = tf.truediv(num_zero, num_fg, name='false_negative')
fg_accuracy = tf.reduce_mean(
tf.gather(correct, fg_inds), name='fg_accuracy')
......@@ -163,12 +167,17 @@ def fastrcnn_predictions(boxes, probs):
Generate final results from predictions of all proposals.
Args:
boxes: n#catx4 floatbox in float32
boxes: n#classx4 floatbox in float32
probs: nx#class
Returns:
indices: Kx2. Each is (box_id, class_id)
probs: K floats
"""
assert boxes.shape[1] == cfg.DATA.NUM_CATEGORY
assert boxes.shape[1] == cfg.DATA.NUM_CLASS
assert probs.shape[1] == cfg.DATA.NUM_CLASS
boxes = tf.transpose(boxes, [1, 0, 2]) # #catxnx4
boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4
boxes.set_shape([None, cfg.DATA.NUM_CATEGORY, None])
probs = tf.transpose(probs[:, 1:], [1, 0]) # #catxn
def f(X):
......@@ -209,8 +218,9 @@ def fastrcnn_predictions(boxes, probs):
tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(probs)),
sorted=False)
filtered_selection = tf.gather(selected_indices, topk_indices)
filtered_selection = tf.reverse(filtered_selection, axis=[1], name='filtered_indices')
return filtered_selection, topk_probs
cat_ids, box_ids = tf.unstack(filtered_selection, axis=1)
final_ids = tf.stack([box_ids, cat_ids + 1], axis=1, name='final_ids') # Kx2, each is (box_id, class_id)
return final_ids, topk_probs
"""
......@@ -267,3 +277,85 @@ def fastrcnn_4conv1fc_head(*args, **kwargs):
def fastrcnn_4conv1fc_gn_head(*args, **kwargs):
return fastrcnn_Xconv1fc_head(*args, num_convs=4, norm='GN', **kwargs)
class FastRCNNHead(object):
"""
A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
"""
def __init__(self, input_boxes, box_logits, label_logits, bbox_regression_weights,
labels=None, matched_gt_boxes_per_fg=None):
"""
Args:
input_boxes: Nx4, inputs to the head
box_logits: Nx#classx4, the output of the head
label_logits: Nx#class, the output of the head
bbox_regression_weights: a 4 element tensor
labels: N, each in [0, #class-1], the true label for each input box
matched_gt_boxes_per_fg: #fgx4, the matching gt boxes for each fg input box
The last two arguments could be None when not training.
"""
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
@memoized
def fg_inds_in_inputs(self):
""" Returns: #fg indices in [0, N-1] """
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds_in_inputs')
@memoized
def fg_input_boxes(self):
""" Returns: #fgx4 """
return tf.gather(self.input_boxes, self.fg_inds_in_inputs(), name='fg_input_boxes')
@memoized
def fg_box_logits(self):
""" Returns: #fg x #class x 4 """
return tf.gather(self.box_logits, self.fg_inds_in_inputs(), name='fg_box_logits')
@memoized
def fg_labels(self):
""" Returns: #fg """
return tf.gather(self.labels, self.fg_inds_in_inputs(), name='fg_labels')
@memoized
def losses(self):
encoded_fg_gt_boxes = encode_bbox_target(
self.matched_gt_boxes_per_fg,
self.fg_input_boxes()) * self.bbox_regression_weights
return fastrcnn_losses(
self.labels, self.label_logits,
encoded_fg_gt_boxes, self.fg_box_logits()
)
@memoized
def decoded_output_boxes(self):
""" Returns: N x #class x 4 """
anchors = tf.tile(tf.expand_dims(self.input_boxes, 1),
[1, cfg.DATA.NUM_CLASS, 1]) # N x #class x 4
decoded_boxes = decode_bbox_target(
self.box_logits / self.bbox_regression_weights,
anchors
)
return decoded_boxes
@memoized
def decoded_output_boxes_for_true_label(self):
""" Returns: Nx4 decoded boxes """
indices = tf.stack([
tf.range(tf.size(self.labels, out_type=tf.int64)),
self.labels
])
needed_logits = tf.gather_nd(self.box_logits, indices)
decoded = decode_bbox_target(
needed_logits / self.bbox_regression_weights,
self.input_boxes
)
return decoded
@memoized
def output_scores(self, name=None):
""" Returns: N x #class scores, summed to one for each box."""
return tf.nn.softmax(self.label_logits, name=name)
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment