Commit 1ae8b540 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] refactor into generalized rcnn

parent 55603b75
...@@ -91,9 +91,9 @@ Mask R-CNN results contain both box and mask mAP. ...@@ -91,9 +91,9 @@ Mask R-CNN results contain both box and mask mAP.
| R50-FPN | 42.0;36.3 | | 41h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> | | R50-FPN | 42.0;36.3 | | 41h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
| R50-FPN | 39.5;35.2 | 39.5;34.4<sup>[2](#ft2)</sup> | 33h | <details><summary>+ConvGNHead</summary>`MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head` </details> | | R50-FPN | 39.5;35.2 | 39.5;34.4<sup>[2](#ft2)</sup> | 33h | <details><summary>+ConvGNHead</summary>`MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head` </details> |
| R50-FPN | 40.0;36.2 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz) | 40.3;35.7 | 40h | <details><summary>+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` | | R50-FPN | 40.0;36.2 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz) | 40.3;35.7 | 40h | <details><summary>+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` |
| R101-C4 | 41.4;35.2 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101C4-MaskRCNN-Standard.npz) | | 60h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCK=[3,4,23,3]` </details> | | R101-C4 | 41.4;35.2 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101C4-MaskRCNN-Standard.npz) | | 60h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 40.4;36.6 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-Standard.npz) | 40.9;36.4 | 38h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCK=[3,4,23,3]` </details> | | R101-FPN | 40.4;36.6 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-Standard.npz) | 40.9;36.4 | 38h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 46.5;40.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-BetterParams.npz) <sup>[3](#ft3)</sup> | | 73h | <details><summary>+++</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCK=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[420000,500000,540000]` </details> | | R101-FPN | 46.5;40.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-BetterParams.npz) <sup>[3](#ft3)</sup> | | 73h | <details><summary>+++</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[420000,500000,540000]` </details> |
<a id="ft1">1</a>: Here we comapre models that have identical training & inference cost between the two implementation. However their numbers are different due to many small implementation details. <a id="ft1">1</a>: Here we comapre models that have identical training & inference cost between the two implementation. However their numbers are different due to many small implementation details.
......
...@@ -86,8 +86,8 @@ _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, to be populated ...@@ -86,8 +86,8 @@ _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, to be populated
# basemodel ---------------------- # basemodel ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz _C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
_C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50 _C.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101 # RESNET_NUM_BLOCKS = [3, 4, 23, 3] # for resnet101
_C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers _C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN _C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 1, 2 _C.BACKBONE.FREEZE_AT = 2 # options: 0, 1, 2
......
...@@ -160,7 +160,7 @@ def multilevel_rpn_losses( ...@@ -160,7 +160,7 @@ def multilevel_rpn_losses(
total_label_loss = tf.add_n(losses[::2], name='label_loss') total_label_loss = tf.add_n(losses[::2], name='label_loss')
total_box_loss = tf.add_n(losses[1::2], name='box_loss') total_box_loss = tf.add_n(losses[1::2], name='box_loss')
add_moving_summary(total_label_loss, total_box_loss) add_moving_summary(total_label_loss, total_box_loss)
return total_label_loss, total_box_loss return [total_label_loss, total_box_loss]
@under_name_scope() @under_name_scope()
...@@ -179,11 +179,11 @@ def generate_fpn_proposals( ...@@ -179,11 +179,11 @@ def generate_fpn_proposals(
assert len(multilevel_pred_boxes) == num_lvl assert len(multilevel_pred_boxes) == num_lvl
assert len(multilevel_label_logits) == num_lvl assert len(multilevel_label_logits) == num_lvl
ctx = get_current_tower_context() training = get_current_tower_context().is_training
all_boxes = [] all_boxes = []
all_scores = [] all_scores = []
if cfg.FPN.PROPOSAL_MODE == 'Level': if cfg.FPN.PROPOSAL_MODE == 'Level':
fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('Lvl{}'.format(lvl + 2)): with tf.name_scope('Lvl{}'.format(lvl + 2)):
pred_boxes_decoded = multilevel_pred_boxes[lvl] pred_boxes_decoded = multilevel_pred_boxes[lvl]
...@@ -210,8 +210,8 @@ def generate_fpn_proposals( ...@@ -210,8 +210,8 @@ def generate_fpn_proposals(
all_scores = tf.concat(all_scores, axis=0) all_scores = tf.concat(all_scores, axis=0)
proposal_boxes, proposal_scores = generate_rpn_proposals( proposal_boxes, proposal_scores = generate_rpn_proposals(
all_boxes, all_scores, image_shape2d, all_boxes, all_scores, image_shape2d,
cfg.RPN.TRAIN_PRE_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PRE_NMS_TOPK, cfg.RPN.TRAIN_PRE_NMS_TOPK if training else cfg.RPN.TEST_PRE_NMS_TOPK,
cfg.RPN.TRAIN_POST_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_POST_NMS_TOPK) cfg.RPN.TRAIN_POST_NMS_TOPK if training else cfg.RPN.TEST_POST_NMS_TOPK)
tf.sigmoid(proposal_scores, name='probs') # for visualization tf.sigmoid(proposal_scores, name='probs') # for visualization
return tf.stop_gradient(proposal_boxes, name='boxes'), \ return tf.stop_gradient(proposal_boxes, name='boxes'), \
......
...@@ -171,7 +171,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -171,7 +171,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
add_moving_summary(label_loss, box_loss, accuracy, add_moving_summary(label_loss, box_loss, accuracy,
fg_accuracy, false_negative, tf.to_float(num_fg, name='num_fg_label')) fg_accuracy, false_negative, tf.to_float(num_fg, name='num_fg_label'))
return label_loss, box_loss return [label_loss, box_loss]
@under_name_scope() @under_name_scope()
......
...@@ -98,7 +98,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits): ...@@ -98,7 +98,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss') box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss')
add_moving_summary(label_loss, box_loss, nr_valid, nr_pos) add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
return label_loss, box_loss return [label_loss, box_loss]
@under_name_scope() @under_name_scope()
......
...@@ -62,6 +62,10 @@ class DetectionModel(ModelDesc): ...@@ -62,6 +62,10 @@ class DetectionModel(ModelDesc):
image = image_preprocess(image, bgr=True) image = image_preprocess(image, bgr=True)
return tf.transpose(image, [0, 3, 1, 2]) return tf.transpose(image, [0, 3, 1, 2])
@property
def training(self):
return get_current_tower_context().is_training
def optimizer(self): def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False) lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
tf.summary.scalar('learning_rate-summary', lr) tf.summary.scalar('learning_rate-summary', lr)
...@@ -86,6 +90,28 @@ class DetectionModel(ModelDesc): ...@@ -86,6 +90,28 @@ class DetectionModel(ModelDesc):
out.append('output/masks') out.append('output/masks')
return ['image'], out return ['image'], out
def build_graph(self, *inputs):
inputs = dict(zip(self.input_names, inputs))
image = self.preprocess(inputs['image']) # 1CHW
features = self.backbone(image)
proposals, rpn_losses = self.rpn(image, features, inputs) # inputs?
targets = [inputs['gt_boxes'], inputs['gt_labels']]
if 'gt_masks' in inputs:
targets.append(inputs['gt_masks'])
head_losses = self.roi_heads(image, features, proposals, targets)
if self.training:
wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
total_cost = tf.add_n(
rpn_losses + head_losses + [wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
class ResNetC4Model(DetectionModel): class ResNetC4Model(DetectionModel):
def inputs(self): def inputs(self):
...@@ -101,15 +127,12 @@ class ResNetC4Model(DetectionModel): ...@@ -101,15 +127,12 @@ class ResNetC4Model(DetectionModel):
) # NR_GT x height x width ) # NR_GT x height x width
return ret return ret
def build_graph(self, *inputs): def backbone(self, image):
# TODO need to make tensorpack handles dict better return [resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCKS[:3])]
inputs = dict(zip(self.input_names, inputs))
is_training = get_current_tower_context().is_training
image = self.preprocess(inputs['image']) # 1CHW
featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3]) def rpn(self, image, features, inputs):
featuremap = features[0]
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)
anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes']) anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
anchors = anchors.narrow_to(featuremap) anchors = anchors.narrow_to(featuremap)
...@@ -119,22 +142,33 @@ class ResNetC4Model(DetectionModel): ...@@ -119,22 +142,33 @@ class ResNetC4Model(DetectionModel):
tf.reshape(pred_boxes_decoded, [-1, 4]), tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]), tf.reshape(rpn_label_logits, [-1]),
image_shape2d, image_shape2d,
cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK, cfg.RPN.TRAIN_PRE_NMS_TOPK if self.training else cfg.RPN.TEST_PRE_NMS_TOPK,
cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK) cfg.RPN.TRAIN_POST_NMS_TOPK if self.training else cfg.RPN.TEST_POST_NMS_TOPK)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels'] if self.training:
if is_training: losses = rpn_losses(
# sample proposal boxes in training anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits)
proposals = sample_fast_rcnn_targets(proposal_boxes, gt_boxes, gt_labels)
else: else:
# The boxes to be used to crop RoIs. losses = []
# Use all proposal boxes in inference
proposals = BoxProposals(proposal_boxes) return BoxProposals(proposal_boxes), losses
def roi_heads(self, image, features, proposals, targets):
image_shape2d = tf.shape(image)[2:] # h,w
featuremap = features[0]
gt_boxes, gt_labels, *_ = targets
if self.training:
# sample proposal boxes in training
proposals = sample_fast_rcnn_targets(proposals.boxes, gt_boxes, gt_labels)
# The boxes to be used to crop RoIs.
# Use all proposal boxes in inference
boxes_on_featuremap = proposals.boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE) boxes_on_featuremap = proposals.boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14) roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1]) # nxcx7x7 feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCKS[-1]) # nxcx7x7
# Keep C5 feature to be shared with mask branch # Keep C5 feature to be shared with mask branch
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first') feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS) fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
...@@ -142,16 +176,11 @@ class ResNetC4Model(DetectionModel): ...@@ -142,16 +176,11 @@ class ResNetC4Model(DetectionModel):
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits, fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)) tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training: if self.training:
all_losses = [] all_losses = fastrcnn_head.losses()
# rpn loss
all_losses.extend(rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits))
# fastrcnn loss
all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK: if cfg.MODE_MASK:
gt_masks = targets[2]
# maskrcnn loss # maskrcnn loss
# In training, mask branch shares the same C5 feature. # In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, proposals.fg_inds()) fg_feature = tf.gather(feature_fastrcnn, proposals.fg_inds())
...@@ -159,20 +188,13 @@ class ResNetC4Model(DetectionModel): ...@@ -159,20 +188,13 @@ class ResNetC4Model(DetectionModel):
'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0) # #fg x #cat x 14x14 'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0) # #fg x #cat x 14x14
target_masks_for_fg = crop_and_resize( target_masks_for_fg = crop_and_resize(
tf.expand_dims(inputs['gt_masks'], 1), tf.expand_dims(gt_masks, 1),
proposals.fg_boxes(), proposals.fg_boxes(),
proposals.fg_inds_wrt_gt, 14, proposals.fg_inds_wrt_gt, 14,
pad_border=False) # nfg x 1x14x14 pad_border=False) # nfg x 1x14x14
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets') target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)) all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
return all_losses
wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n(all_losses, 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else: else:
decoded_boxes = fastrcnn_head.decoded_output_boxes() decoded_boxes = fastrcnn_head.decoded_output_boxes()
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes') decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
...@@ -182,12 +204,13 @@ class ResNetC4Model(DetectionModel): ...@@ -182,12 +204,13 @@ class ResNetC4Model(DetectionModel):
if cfg.MODE_MASK: if cfg.MODE_MASK:
roi_resized = roi_align(featuremap, final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE), 14) roi_resized = roi_align(featuremap, final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1]) feature_maskrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCKS[-1])
mask_logits = maskrcnn_upXconv_head( mask_logits = maskrcnn_upXconv_head(
'maskrcnn', feature_maskrcnn, cfg.DATA.NUM_CATEGORY, 0) # #result x #cat x 14x14 'maskrcnn', feature_maskrcnn, cfg.DATA.NUM_CATEGORY, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1) indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14 final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
tf.sigmoid(final_mask_logits, name='output/masks') tf.sigmoid(final_mask_logits, name='output/masks')
return []
class ResNetFPNModel(DetectionModel): class ResNetFPNModel(DetectionModel):
...@@ -225,28 +248,25 @@ class ResNetFPNModel(DetectionModel): ...@@ -225,28 +248,25 @@ class ResNetFPNModel(DetectionModel):
anchors[i] = anchors[i].narrow_to(p23456[i]) anchors[i] = anchors[i].narrow_to(p23456[i])
def build_graph(self, *inputs): def backbone(self, image):
inputs = dict(zip(self.input_names, inputs)) c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCKS)
num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES) p23456 = fpn_model('fpn', c2345)
assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level return p23456
is_training = get_current_tower_context().is_training
def rpn(self, image, features, inputs):
assert len(cfg.RPN.ANCHOR_SIZES) == len(cfg.FPN.ANCHOR_STRIDES)
image_shape2d = tf.shape(image)[2:] # h,w
all_anchors_fpn = get_all_anchors_fpn() all_anchors_fpn = get_all_anchors_fpn()
multilevel_anchors = [RPNAnchors( multilevel_anchors = [RPNAnchors(
all_anchors_fpn[i], all_anchors_fpn[i],
inputs['anchor_labels_lvl{}'.format(i + 2)], inputs['anchor_labels_lvl{}'.format(i + 2)],
inputs['anchor_boxes_lvl{}'.format(i + 2)]) for i in range(len(all_anchors_fpn))] inputs['anchor_boxes_lvl{}'.format(i + 2)]) for i in range(len(all_anchors_fpn))]
self.slice_feature_and_anchors(image_shape2d, features, multilevel_anchors)
image = self.preprocess(inputs['image']) # 1CHW
image_shape2d = tf.shape(image)[2:] # h,w
c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
p23456 = fpn_model('fpn', c2345)
self.slice_feature_and_anchors(image_shape2d, p23456, multilevel_anchors)
# Multi-Level RPN Proposals # Multi-Level RPN Proposals
rpn_outputs = [rpn_head('rpn', pi, cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS)) rpn_outputs = [rpn_head('rpn', pi, cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS))
for pi in p23456] for pi in features]
multilevel_label_logits = [k[0] for k in rpn_outputs] multilevel_label_logits = [k[0] for k in rpn_outputs]
multilevel_box_logits = [k[1] for k in rpn_outputs] multilevel_box_logits = [k[1] for k in rpn_outputs]
multilevel_pred_boxes = [anchor.decode_logits(logits) multilevel_pred_boxes = [anchor.decode_logits(logits)
...@@ -255,15 +275,25 @@ class ResNetFPNModel(DetectionModel): ...@@ -255,15 +275,25 @@ class ResNetFPNModel(DetectionModel):
proposal_boxes, proposal_scores = generate_fpn_proposals( proposal_boxes, proposal_scores = generate_fpn_proposals(
multilevel_pred_boxes, multilevel_label_logits, image_shape2d) multilevel_pred_boxes, multilevel_label_logits, image_shape2d)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels'] if self.training:
if is_training: losses = multilevel_rpn_losses(
proposals = sample_fast_rcnn_targets(proposal_boxes, gt_boxes, gt_labels) multilevel_anchors, multilevel_label_logits, multilevel_box_logits)
else: else:
proposals = BoxProposals(proposal_boxes) losses = []
return BoxProposals(proposal_boxes), losses
def roi_heads(self, image, features, proposals, targets):
image_shape2d = tf.shape(image)[2:] # h,w
assert len(features) == 5, "Features have to be P23456!"
gt_boxes, gt_labels, *_ = targets
if self.training:
proposals = sample_fast_rcnn_targets(proposals.boxes, gt_boxes, gt_labels)
fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC) fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC)
if not cfg.FPN.CASCADE: if not cfg.FPN.CASCADE:
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], proposals.boxes, 7) roi_feature_fastrcnn = multilevel_roi_align(features[:4], proposals.boxes, 7)
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn) head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs( fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
...@@ -272,42 +302,32 @@ class ResNetFPNModel(DetectionModel): ...@@ -272,42 +302,32 @@ class ResNetFPNModel(DetectionModel):
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)) tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
else: else:
def roi_func(boxes): def roi_func(boxes):
return multilevel_roi_align(p23456[:4], boxes, 7) return multilevel_roi_align(features[:4], boxes, 7)
fastrcnn_head = CascadeRCNNHead( fastrcnn_head = CascadeRCNNHead(
proposals, roi_func, fastrcnn_head_func, image_shape2d, cfg.DATA.NUM_CLASS) proposals, roi_func, fastrcnn_head_func, image_shape2d, cfg.DATA.NUM_CLASS)
if is_training: if self.training:
all_losses = [] all_losses = fastrcnn_head.losses()
all_losses.extend(multilevel_rpn_losses(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits))
all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK: if cfg.MODE_MASK:
gt_masks = targets[2]
# maskrcnn loss # maskrcnn loss
roi_feature_maskrcnn = multilevel_roi_align( roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], proposals.fg_boxes(), 14, features[:4], proposals.fg_boxes(), 14,
name_scope='multilevel_roi_align_mask') name_scope='multilevel_roi_align_mask')
maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC) maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC)
mask_logits = maskrcnn_head_func( mask_logits = maskrcnn_head_func(
'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28 'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28
target_masks_for_fg = crop_and_resize( target_masks_for_fg = crop_and_resize(
tf.expand_dims(inputs['gt_masks'], 1), tf.expand_dims(gt_masks, 1),
proposals.fg_boxes(), proposals.fg_boxes(),
proposals.fg_inds_wrt_gt, 28, proposals.fg_inds_wrt_gt, 28,
pad_border=False) # fg x 1x28x28 pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets') target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)) all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
return all_losses
wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n(all_losses, 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else: else:
decoded_boxes = fastrcnn_head.decoded_output_boxes() decoded_boxes = fastrcnn_head.decoded_output_boxes()
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes') decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
...@@ -316,13 +336,14 @@ class ResNetFPNModel(DetectionModel): ...@@ -316,13 +336,14 @@ class ResNetFPNModel(DetectionModel):
decoded_boxes, label_scores, name_scope='output') decoded_boxes, label_scores, name_scope='output')
if cfg.MODE_MASK: if cfg.MODE_MASK:
# Cascade inference needs roi transform with refined boxes. # Cascade inference needs roi transform with refined boxes.
roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14) roi_feature_maskrcnn = multilevel_roi_align(features[:4], final_boxes, 14)
maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC) maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC)
mask_logits = maskrcnn_head_func( mask_logits = maskrcnn_head_func(
'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28 'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1) indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28 final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28
tf.sigmoid(final_mask_logits, name='output/masks') tf.sigmoid(final_mask_logits, name='output/masks')
return []
def visualize(model, model_path, nr_visualize=100, output_dir='output'): def visualize(model, model_path, nr_visualize=100, output_dir='output'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment