Commit 8d2febb7 authored by Yuxin Wu

[MaskRCNN] RPNAnchors as a class

parent e61946b2
@@ -2,6 +2,7 @@
 # File: model_box.py

 import numpy as np
+from collections import namedtuple
 import tensorflow as tf

 from tensorpack.tfutils.scope_utils import under_name_scope
@@ -170,6 +171,32 @@ def roi_align(featuremap, boxes, resolution):
     return ret

+class RPNAnchors(namedtuple('_RPNAnchors', ['boxes', 'gt_labels', 'gt_boxes'])):
+    """
+    boxes (FS x FS x NA x 4): The anchor boxes.
+    gt_labels (FS x FS x NA): Groundtruth labels for each anchor.
+    gt_boxes (FS x FS x NA x 4): Groundtruth boxes corresponding to each anchor.
+    """
+    def encoded_gt_boxes(self):
+        return encode_bbox_target(self.gt_boxes, self.boxes)
+
+    def decode_logits(self, logits):
+        return decode_bbox_target(logits, self.boxes)
+
+    @under_name_scope()
+    def narrow_to(self, featuremap):
+        """
+        Slice anchors to the spatial size of this featuremap.
+        """
+        shape2d = tf.shape(featuremap)[2:]  # h,w
+        slice3d = tf.concat([shape2d, [-1]], axis=0)
+        slice4d = tf.concat([shape2d, [-1, -1]], axis=0)
+        boxes = tf.slice(self.boxes, [0, 0, 0, 0], slice4d)
+        gt_labels = tf.slice(self.gt_labels, [0, 0, 0], slice3d)
+        gt_boxes = tf.slice(self.gt_boxes, [0, 0, 0, 0], slice4d)
+        return RPNAnchors(boxes, gt_labels, gt_boxes)
+
 if __name__ == '__main__':
     """
     Demonstrate what's wrong with tf.image.crop_and_resize:
...
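
For readers skimming the diff: RPNAnchors bundles the anchor grid together with its per-anchor groundtruth, so the encode/decode helpers no longer need the anchors passed around separately. A minimal usage sketch (illustrative, not part of the commit; assumes TF 1.x graph mode and that the example's model_box module, along with its config, is importable):

    import tensorflow as tf
    from model_box import RPNAnchors

    FS, NA = 64, 15   # hypothetical anchor-grid size and anchors per location
    anchors = RPNAnchors(
        boxes=tf.ones([FS, FS, NA, 4]),
        gt_labels=tf.zeros([FS, FS, NA], tf.int32),
        gt_boxes=tf.ones([FS, FS, NA, 4]))

    # Narrow the full anchor grid to a 38x50 featuremap (NCHW layout).
    featuremap = tf.zeros([1, 1024, 38, 50])
    narrowed = anchors.narrow_to(featuremap)

    with tf.Session() as sess:
        print(sess.run(tf.shape(narrowed.boxes)))      # [38 50 15 4]
        print(sess.run(tf.shape(narrowed.gt_labels)))  # [38 50 15]

Because RPNAnchors subclasses a namedtuple, narrow_to returns a new immutable instance rather than mutating fields in place.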
@@ -21,7 +21,6 @@ assert six.PY3, "FasterRCNN requires Python 3!"

 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
-from tensorpack.tfutils.scope_utils import under_name_scope
 from tensorpack.tfutils import optimizer
 from tensorpack.tfutils.common import get_tf_version_number
 import tensorpack.utils.viz as tpviz
@@ -38,7 +37,8 @@ from model import (
     maskrcnn_upXconv_head, maskrcnn_loss,
     fpn_model, multilevel_roi_align)
 from model_box import (
-    clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize, roi_align)
+    clip_boxes, decode_bbox_target, encode_bbox_target,
+    crop_and_resize, roi_align, RPNAnchors)
 from data import (
     get_train_dataflow, get_eval_dataflow,
     get_all_anchors, get_all_anchors_fpn)
@@ -56,24 +56,6 @@ class DetectionModel(ModelDesc):
         image = image_preprocess(image, bgr=True)
         return tf.transpose(image, [0, 3, 1, 2])

-    @under_name_scope()
-    def narrow_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
-        """
-        Slice anchors/anchor_labels/anchor_boxes to the spatial size of this featuremap.
-
-        Args:
-            anchors (FS x FS x NA x 4):
-            anchor_labels (FS x FS x NA):
-            anchor_boxes (FS x FS x NA x 4):
-        """
-        shape2d = tf.shape(featuremap)[2:]  # h,w
-        slice3d = tf.concat([shape2d, [-1]], axis=0)
-        slice4d = tf.concat([shape2d, [-1, -1]], axis=0)
-        anchors = tf.slice(anchors, [0, 0, 0, 0], slice4d)
-        anchor_labels = tf.slice(anchor_labels, [0, 0, 0], slice3d)
-        anchor_boxes = tf.slice(anchor_boxes, [0, 0, 0, 0], slice4d)
-        return anchors, anchor_labels, anchor_boxes
-
     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
         tf.summary.scalar('learning_rate-summary', lr)
@@ -183,12 +165,11 @@ class ResNetC4Model(DetectionModel):
         featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, cfg.RPN.NUM_ANCHOR)

-        fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
-            featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
-        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
+        anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
+        anchors = anchors.narrow_to(featuremap)

         image_shape2d = tf.shape(image)[2:]     # h,w
-        pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
+        pred_boxes_decoded = anchors.decode_logits(rpn_box_logits)  # fHxfWxNAx4, floatbox
         proposal_boxes, proposal_scores = generate_rpn_proposals(
             tf.reshape(pred_boxes_decoded, [-1, 4]),
             tf.reshape(rpn_label_logits, [-1]),
@@ -216,7 +197,7 @@ class ResNetC4Model(DetectionModel):
         if is_training:
             # rpn loss
             rpn_label_loss, rpn_box_loss = rpn_losses(
-                anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
+                anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits)

             # fastrcnn loss
             matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
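
Note that the RPN loss now consumes anchors.encoded_gt_boxes(), i.e. the groundtruth boxes expressed as regression targets relative to the anchors. A quick property sketch (illustrative; relies on encode_bbox_target/decode_bbox_target being mutual inverses for reasonably-sized boxes, which is how the pair in model_box.py is designed):

    import numpy as np
    import tensorflow as tf
    from model_box import encode_bbox_target, decode_bbox_target

    anchors = tf.constant([[10., 10., 50., 50.]])   # hypothetical x1y1x2y2 anchor
    gt = tf.constant([[12., 8., 55., 47.]])         # hypothetical matched gt box

    encoded = encode_bbox_target(gt, anchors)        # what rpn_losses regresses to
    roundtrip = decode_bbox_target(encoded, anchors)

    with tf.Session() as sess:
        print(np.allclose(*sess.run([roundtrip, gt])))   # True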
@@ -274,6 +255,7 @@ class ResNetC4Model(DetectionModel):


 class ResNetFPNModel(DetectionModel):
+
     def inputs(self):
         ret = [
             tf.placeholder(tf.float32, (None, None, 3), 'image')]
@@ -293,14 +275,28 @@ class ResNetFPNModel(DetectionModel):
             ) # NR_GT x height x width
         return ret
+
+    def slice_feature_and_anchors(self, image_shape2d, p23456, anchors):
+        for i, stride in enumerate(cfg.FPN.ANCHOR_STRIDES):
+            with tf.name_scope('FPN_slice_lvl{}'.format(i)):
+                if i < 3:
+                    # Images are padded for p5, so the features are too large for p2-p4.
+                    # Slicing them back seems to have no effect on mAP.
+                    pi = p23456[i]
+                    target_shape = tf.to_int32(tf.ceil(tf.to_float(image_shape2d) * (1.0 / stride)))
+                    p23456[i] = tf.slice(pi, [0, 0, 0, 0],
+                                         tf.concat([[-1, -1], target_shape], axis=0))
+                    p23456[i].set_shape([1, pi.shape[1], None, None])
+                anchors[i] = anchors[i].narrow_to(p23456[i])
+
     def build_graph(self, *inputs):
         num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES)
         assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level
         is_training = get_current_tower_context().is_training
         image = inputs[0]
         input_anchors = inputs[1: 1 + 2 * num_fpn_level]
-        multilevel_anchor_labels = input_anchors[0::2]
-        multilevel_anchor_boxes = input_anchors[1::2]
+        multilevel_anchors = [RPNAnchors(*args) for args in
+                              zip(get_all_anchors_fpn(), input_anchors[0::2], input_anchors[1::2])]
         gt_boxes, gt_labels = inputs[11], inputs[12]
         if cfg.MODE_MASK:
             gt_masks = inputs[-1]
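
The comprehension above assumes the per-level inputs alternate labels then boxes; zip regroups them with the precomputed anchors from get_all_anchors_fpn(). A pure-Python illustration of the regrouping (stand-in strings, not real tensors):

    all_anchors = ['anc_p2', 'anc_p3', 'anc_p4', 'anc_p5', 'anc_p6']
    input_anchors = ['lbl_p2', 'box_p2', 'lbl_p3', 'box_p3', 'lbl_p4', 'box_p4',
                     'lbl_p5', 'box_p5', 'lbl_p6', 'box_p6']

    levels = list(zip(all_anchors, input_anchors[0::2], input_anchors[1::2]))
    print(levels[0])   # ('anc_p2', 'lbl_p2', 'box_p2')
    print(levels[-1])  # ('anc_p6', 'lbl_p6', 'box_p6')

Each triple becomes one RPNAnchors(boxes, gt_labels, gt_boxes) instance, one per FPN level.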
@@ -310,15 +306,7 @@ class ResNetFPNModel(DetectionModel):
         c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
         p23456 = fpn_model('fpn', c2345)
-
-        # Images are padded for p5, which are too large for p2-p4.
-        # This seems to have no effect on mAP.
-        for i, stride in enumerate(cfg.FPN.ANCHOR_STRIDES[:3]):
-            pi = p23456[i]
-            target_shape = tf.to_int32(tf.ceil(tf.to_float(image_shape2d) * (1.0 / stride)))
-            p23456[i] = tf.slice(pi, [0, 0, 0, 0],
-                                 tf.concat([[-1, -1], target_shape], axis=0))
-            p23456[i].set_shape([1, pi.shape[1], None, None])
+        self.slice_feature_and_anchors(image_shape2d, p23456, multilevel_anchors)

         # Multi-Level RPN Proposals
         multilevel_proposals = []
@@ -327,13 +315,8 @@ class ResNetFPNModel(DetectionModel):
             rpn_label_logits, rpn_box_logits = rpn_head(
                 'rpn', p23456[lvl], cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS))
             with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
-                anchors = tf.constant(get_all_anchors_fpn()[lvl], name='rpn_anchor_lvl{}'.format(lvl + 2))
-                anchors, anchor_labels, anchor_boxes = \
-                    self.narrow_to_featuremap(p23456[lvl], anchors,
-                                              multilevel_anchor_labels[lvl],
-                                              multilevel_anchor_boxes[lvl])
-                anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
-                pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
+                anchors = multilevel_anchors[lvl]
+                pred_boxes_decoded = anchors.decode_logits(rpn_box_logits)
                 proposal_boxes, proposal_scores = generate_rpn_proposals(
                     tf.reshape(pred_boxes_decoded, [-1, 4]),
                     tf.reshape(rpn_label_logits, [-1]),
@@ -342,7 +325,7 @@ class ResNetFPNModel(DetectionModel):
                 multilevel_proposals.append((proposal_boxes, proposal_scores))
                 if is_training:
                     label_loss, box_loss = rpn_losses(
-                        anchor_labels, anchor_boxes_encoded,
+                        anchors.gt_labels, anchors.encoded_gt_boxes(),
                         rpn_label_logits, rpn_box_logits)
                     rpn_loss_collection.extend([label_loss, box_loss])
...
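
As a worked example of the crop arithmetic in slice_feature_and_anchors (illustrative numbers; assumes the default cfg.FPN.ANCHOR_STRIDES of (4, 8, 16, 32, 64)), an image padded to 800x1216 for the coarsest stride yields these per-level target shapes:

    import math

    image_shape2d = (800, 1216)   # padded input, a multiple of the coarsest stride
    for i, stride in enumerate((4, 8, 16, 32, 64)):
        target = tuple(int(math.ceil(s / stride)) for s in image_shape2d)
        print('p{}: {}'.format(i + 2, target))
    # p2: (200, 304)  p3: (100, 152)  p4: (50, 76)  p5: (25, 38)  p6: (13, 19)

Only the p2-p4 features are sliced back (the i < 3 branch); the anchors, by contrast, are narrowed to every level's featuremap so that labels, boxes, and logits stay aligned per level.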