Commit 99ba595d authored by Yuxin Wu

[MaskRCNN] use FastRCNNHead to manage head inputs&outputs

parent 02eb02e2
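
The commit routes the Fast R-CNN head tensors of both models through a single FastRCNNHead object. The sketch below is not part of the commit; it only illustrates, based on the hunks that follow, how the new class is meant to be wired in. The placeholder tensors and their shapes are assumptions standing in for what the training code actually produces.

# Usage sketch only (assumed placeholders; real tensors come from the training graph).
import tensorflow as tf

from config import config as cfg
from model_box import clip_boxes
from model_frcnn import fastrcnn_predictions, FastRCNNHead

rcnn_boxes = tf.placeholder(tf.float32, [None, 4])                       # sampled proposal boxes, Nx4
label_logits = tf.placeholder(tf.float32, [None, cfg.DATA.NUM_CLASS])    # N x #class
box_logits = tf.placeholder(tf.float32, [None, cfg.DATA.NUM_CLASS, 4])   # N x #class x 4
rcnn_labels = tf.placeholder(tf.int64, [None])                           # per-box labels, 0 = background
matched_gt_boxes = tf.placeholder(tf.float32, [None, 4])                 # matched gt box per fg proposal
image_shape2d = tf.placeholder(tf.int32, [2])                            # (h, w)

# One object now carries the head's inputs and outputs; the last two arguments
# may be None at inference time.
head = FastRCNNHead(rcnn_boxes, box_logits, label_logits,
                    tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS),
                    rcnn_labels, matched_gt_boxes)

# Training path: losses() encodes the matched gt boxes against the fg proposals internally.
label_loss, box_loss = head.losses()

# Inference path: decode per-class boxes, clip to the image, then run per-class NMS.
decoded = clip_boxes(head.decoded_output_boxes(), image_shape2d)
indices, scores = fastrcnn_predictions(decoded, head.output_scores())    # indices: Kx2, (box_id, class_id)
final_boxes = tf.gather_nd(decoded, indices)
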
@@ -8,9 +8,11 @@ from tensorpack.tfutils.argscope import argscope
 from tensorpack.tfutils.scope_utils import under_name_scope
 from tensorpack.models import (
     Conv2D, FullyConnected, layer_register)
+from tensorpack.utils.argtools import memoized

 from basemodel import GroupNorm
 from utils.box_ops import pairwise_iou
+from model_box import encode_bbox_target, decode_bbox_target
 from config import config as cfg
@@ -113,7 +115,7 @@ def fastrcnn_outputs(feature, num_classes):
     box_regression = FullyConnected(
         'box', feature, num_classes * 4,
         kernel_initializer=tf.random_normal_initializer(stddev=0.001))
-    box_regression = tf.reshape(box_regression, (-1, num_classes, 4))
+    box_regression = tf.reshape(box_regression, (-1, num_classes, 4), name='output_box')
     return classification, box_regression
@@ -125,6 +127,9 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
         label_logits: nxC
         fg_boxes: nfgx4, encoded
         fg_box_logits: nfgxCx4
+
+    Returns:
+        label_loss, box_loss
     """
     label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=labels, logits=label_logits)
@@ -132,10 +137,9 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
     fg_inds = tf.where(labels > 0)[:, 0]
     fg_labels = tf.gather(labels, fg_inds)
-    num_fg = tf.size(fg_inds)
-    indices = tf.stack(
-        [tf.range(num_fg),
-         tf.to_int32(fg_labels)], axis=1)  # #fgx2
+    num_fg = tf.size(fg_inds, out_type=tf.int64)
+    indices = tf.stack(
+        [tf.range(num_fg), fg_labels], axis=1)  # #fgx2
     fg_box_logits = tf.gather_nd(fg_box_logits, indices)

     with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
@@ -143,7 +147,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
         correct = tf.to_float(tf.equal(prediction, labels))  # boolean/integer gather is unavailable on GPU
         accuracy = tf.reduce_mean(correct, name='accuracy')
         fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
-        num_zero = tf.reduce_sum(tf.to_int32(tf.equal(fg_label_pred, 0)), name='num_zero')
+        num_zero = tf.reduce_sum(tf.to_int64(tf.equal(fg_label_pred, 0)), name='num_zero')
         false_negative = tf.truediv(num_zero, num_fg, name='false_negative')
         fg_accuracy = tf.reduce_mean(
             tf.gather(correct, fg_inds), name='fg_accuracy')
@@ -163,12 +167,17 @@ def fastrcnn_predictions(boxes, probs):
     Generate final results from predictions of all proposals.

     Args:
-        boxes: n#catx4 floatbox in float32
+        boxes: n#classx4 floatbox in float32
         probs: nx#class
+
+    Returns:
+        indices: Kx2. Each is (box_id, class_id)
+        probs: K floats
     """
-    assert boxes.shape[1] == cfg.DATA.NUM_CATEGORY
+    assert boxes.shape[1] == cfg.DATA.NUM_CLASS
     assert probs.shape[1] == cfg.DATA.NUM_CLASS
-    boxes = tf.transpose(boxes, [1, 0, 2])  # #catxnx4
+    boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :]  # #catxnx4
+    boxes.set_shape([None, cfg.DATA.NUM_CATEGORY, None])
     probs = tf.transpose(probs[:, 1:], [1, 0])  # #catxn

     def f(X):
@@ -209,8 +218,9 @@ def fastrcnn_predictions(boxes, probs):
         tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(probs)),
         sorted=False)
     filtered_selection = tf.gather(selected_indices, topk_indices)
-    filtered_selection = tf.reverse(filtered_selection, axis=[1], name='filtered_indices')
-    return filtered_selection, topk_probs
+    cat_ids, box_ids = tf.unstack(filtered_selection, axis=1)
+    final_ids = tf.stack([box_ids, cat_ids + 1], axis=1, name='final_ids')  # Kx2, each is (box_id, class_id)
+    return final_ids, topk_probs

     """
@@ -267,3 +277,85 @@ def fastrcnn_4conv1fc_head(*args, **kwargs):

 def fastrcnn_4conv1fc_gn_head(*args, **kwargs):
     return fastrcnn_Xconv1fc_head(*args, num_convs=4, norm='GN', **kwargs)
+
+
+class FastRCNNHead(object):
+    """
+    A class to process & decode inputs/outputs of a fastrcnn classification+regression head.
+    """
+    def __init__(self, input_boxes, box_logits, label_logits, bbox_regression_weights,
+                 labels=None, matched_gt_boxes_per_fg=None):
+        """
+        Args:
+            input_boxes: Nx4, inputs to the head
+            box_logits: Nx#classx4, the output of the head
+            label_logits: Nx#class, the output of the head
+            bbox_regression_weights: a 4 element tensor
+            labels: N, each in [0, #class-1], the true label for each input box
+            matched_gt_boxes_per_fg: #fgx4, the matching gt boxes for each fg input box
+
+            The last two arguments could be None when not training.
+        """
+        for k, v in locals().items():
+            if k != 'self':
+                setattr(self, k, v)
+
+    @memoized
+    def fg_inds_in_inputs(self):
+        """ Returns: #fg indices in [0, N-1] """
+        return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds_in_inputs')
+
+    @memoized
+    def fg_input_boxes(self):
+        """ Returns: #fgx4 """
+        return tf.gather(self.input_boxes, self.fg_inds_in_inputs(), name='fg_input_boxes')
+
+    @memoized
+    def fg_box_logits(self):
+        """ Returns: #fg x #class x 4 """
+        return tf.gather(self.box_logits, self.fg_inds_in_inputs(), name='fg_box_logits')

+    @memoized
+    def fg_labels(self):
+        """ Returns: #fg """
+        return tf.gather(self.labels, self.fg_inds_in_inputs(), name='fg_labels')
+
+    @memoized
+    def losses(self):
+        encoded_fg_gt_boxes = encode_bbox_target(
+            self.matched_gt_boxes_per_fg,
+            self.fg_input_boxes()) * self.bbox_regression_weights
+        return fastrcnn_losses(
+            self.labels, self.label_logits,
+            encoded_fg_gt_boxes, self.fg_box_logits()
+        )
+
+    @memoized
+    def decoded_output_boxes(self):
+        """ Returns: N x #class x 4 """
+        anchors = tf.tile(tf.expand_dims(self.input_boxes, 1),
+                          [1, cfg.DATA.NUM_CLASS, 1])   # N x #class x 4
+        decoded_boxes = decode_bbox_target(
+            self.box_logits / self.bbox_regression_weights,
+            anchors
+        )
+        return decoded_boxes
+
+    @memoized
+    def decoded_output_boxes_for_true_label(self):
+        """ Returns: Nx4 decoded boxes """
+        indices = tf.stack([
+            tf.range(tf.size(self.labels, out_type=tf.int64)),
+            self.labels
+        ])
+        needed_logits = tf.gather_nd(self.box_logits, indices)
+        decoded = decode_bbox_target(
+            needed_logits / self.bbox_regression_weights,
+            self.input_boxes
+        )
+        return decoded
+
+    @memoized
+    def output_scores(self, name=None):
+        """ Returns: N x #class scores, summed to one for each box."""
+        return tf.nn.softmax(self.label_logits, name=name)
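
One design note on the class added above: every accessor is wrapped in tensorpack's @memoized, so the corresponding TF ops are built lazily and at most once per head instance, no matter how many of the call sites in the hunks below ask for them. The toy below (not code from the commit; the Expensive class is purely illustrative) shows that caching behaviour:

from tensorpack.utils.argtools import memoized

class Expensive(object):
    def __init__(self):
        self.calls = 0

    @memoized
    def build(self):
        # Pretend this creates a subgraph; with @memoized it runs once per instance.
        self.calls += 1
        return 'op_%d' % self.calls

e = Expensive()
assert e.build() == 'op_1'
assert e.build() == 'op_1'   # cached: the body did not run again
assert e.calls == 1
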
@@ -35,15 +35,14 @@ import model_frcnn
 import model_mrcnn
 from model_frcnn import (
     sample_fast_rcnn_targets,
-    fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions)
+    fastrcnn_outputs, fastrcnn_predictions, FastRCNNHead)
 from model_mrcnn import maskrcnn_upXconv_head, maskrcnn_loss
 from model_rpn import rpn_head, rpn_losses, generate_rpn_proposals
 from model_fpn import (
     fpn_model, multilevel_roi_align,
     multilevel_rpn_losses, generate_fpn_proposals)
 from model_box import (
-    clip_boxes, decode_bbox_target, encode_bbox_target,
-    crop_and_resize, roi_align, RPNAnchors)
+    clip_boxes, crop_and_resize, roi_align, RPNAnchors)
 from data import (
     get_train_dataflow, get_eval_dataflow,
@@ -73,62 +72,25 @@ class DetectionModel(ModelDesc):
             opt = optimizer.AccumGradOptimizer(opt, 8 // cfg.TRAIN.NUM_GPUS)
         return opt

-    def fastrcnn_training(self, image,
-                          rcnn_labels, fg_rcnn_boxes, gt_boxes_per_fg,
-                          rcnn_label_logits, fg_rcnn_box_logits):
-        """
-        Args:
-            image (NCHW):
-            rcnn_labels (n): labels for each sampled targets
-            fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground targets
-            gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground targets
-            rcnn_label_logits (n): label logits for each sampled targets
-            fg_rcnn_box_logits (fg x #class x 4): box logits for each sampled foreground targets
-        """
-        with tf.name_scope('fg_sample_patch_viz'):
-            fg_sampled_patches = crop_and_resize(
-                image, fg_rcnn_boxes,
-                tf.zeros([tf.shape(fg_rcnn_boxes)[0]], dtype=tf.int32), 300)
-            fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
-            fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
-            tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
-
-        encoded_boxes = encode_bbox_target(
-            gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)
-        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
-            rcnn_labels, rcnn_label_logits,
-            encoded_boxes,
-            fg_rcnn_box_logits)
-        return fastrcnn_label_loss, fastrcnn_box_loss
-
-    def fastrcnn_inference(self, image_shape2d,
-                           rcnn_boxes, rcnn_label_logits, rcnn_box_logits):
+    def fastrcnn_inference(self, image_shape2d, fastrcnn_head):
         """
         Args:
             image_shape2d: h, w
-            rcnn_boxes (nx4): the proposal boxes
-            rcnn_label_logits (n):
-            rcnn_box_logits (nx #class x 4):
+            fastrcnn_head (FastRCNNHead):

         Returns:
             boxes (mx4):
             labels (m): each >= 1
         """
-        rcnn_box_logits = rcnn_box_logits[:, 1:, :]
-        rcnn_box_logits.set_shape([None, cfg.DATA.NUM_CATEGORY, None])
-        label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
-        anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, cfg.DATA.NUM_CATEGORY, 1])   # #proposal x #Cat x 4
-        decoded_boxes = decode_bbox_target(
-            rcnn_box_logits /
-            tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32), anchors)
+        decoded_boxes = fastrcnn_head.decoded_output_boxes()
         decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
+        label_probs = fastrcnn_head.output_scores(name='fastrcnn_all_probs')

-        # indices: Nx2. Each index into (#proposal, #category)
+        # indices: Nx2. Each index into (#box, #class)
         pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
         final_probs = tf.identity(final_probs, 'final_probs')
         final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
-        final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
+        final_labels = tf.gather(pred_indices, 1, axis=1, name='final_labels')
         return final_boxes, final_labels

     def get_inference_tensor_names(self):
@@ -184,10 +146,12 @@ class ResNetC4Model(DetectionModel):
             # sample proposal boxes in training
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                 proposal_boxes, gt_boxes, gt_labels)
+            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt, name='gt_boxes_per_fg_proposal')
         else:
             # The boxes to be used to crop RoIs.
             # Use all proposal boxes in inference
             rcnn_boxes = proposal_boxes
+            rcnn_labels, matched_gt_boxes = None, None

         boxes_on_featuremap = rcnn_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
         roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
@@ -197,37 +161,32 @@ class ResNetC4Model(DetectionModel):
         feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
         fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)

+        fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
+                                     tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS),
+                                     rcnn_labels, matched_gt_boxes)
+
         if is_training:
             # rpn loss
             rpn_label_loss, rpn_box_loss = rpn_losses(
                 anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits)

             # fastrcnn loss
-            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
-            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
-            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
-            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
-            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
-                image, rcnn_labels, fg_sampled_boxes,
-                matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
+            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses()

             if cfg.MODE_MASK:
                 # maskrcnn loss
-                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                 # In training, mask branch shares the same C5 feature.
-                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
+                fg_feature = tf.gather(feature_fastrcnn, fastrcnn_head.fg_inds_in_inputs())
                 mask_logits = maskrcnn_upXconv_head(
                     'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0)   # #fg x #cat x 14x14

                 target_masks_for_fg = crop_and_resize(
                     tf.expand_dims(inputs['gt_masks'], 1),
-                    fg_sampled_boxes,
+                    fastrcnn_head.fg_input_boxes(),
                     fg_inds_wrt_gt, 14,
                     pad_border=False)   # nfg x 1x14x14
                 target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
-                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
+                mrcnn_loss = maskrcnn_loss(mask_logits, fastrcnn_head.fg_labels(), target_masks_for_fg)
             else:
                 mrcnn_loss = 0.0
@@ -242,8 +201,7 @@ class ResNetC4Model(DetectionModel):
             add_moving_summary(total_cost, wd_cost)
             return total_cost
         else:
-            final_boxes, final_labels = self.fastrcnn_inference(
-                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
+            final_boxes, final_labels = self.fastrcnn_inference(image_shape2d, fastrcnn_head)

             if cfg.MODE_MASK:
                 roi_resized = roi_align(featuremap, final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE), 14)
@@ -323,37 +281,32 @@ class ResNetFPNModel(DetectionModel):
         if is_training:
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                 proposal_boxes, gt_boxes, gt_labels)
+            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
         else:
             # The boxes to be used to crop RoIs.
             rcnn_boxes = proposal_boxes
+            rcnn_labels, matched_gt_boxes = None, None

         roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)

         fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC)
         fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
             'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
+        fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
+                                     tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS),
+                                     rcnn_labels, matched_gt_boxes)

         if is_training:
             # rpn loss:
             rpn_label_loss, rpn_box_loss = multilevel_rpn_losses(
                 multilevel_anchors, multilevel_label_logits, multilevel_box_logits)

-            # fastrcnn loss:
-            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
-            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
-            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
-            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
-            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
-                image, rcnn_labels, fg_sampled_boxes,
-                matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
+            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses()

             if cfg.MODE_MASK:
                 # maskrcnn loss
-                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                 roi_feature_maskrcnn = multilevel_roi_align(
-                    p23456[:4], fg_sampled_boxes, 14,
+                    p23456[:4], fastrcnn_head.fg_input_boxes(), 14,
                     name_scope='multilevel_roi_align_mask')
                 maskrcnn_head_func = getattr(model_mrcnn, cfg.FPN.MRCNN_HEAD_FUNC)
                 mask_logits = maskrcnn_head_func(
@@ -361,11 +314,11 @@ class ResNetFPNModel(DetectionModel):
                 target_masks_for_fg = crop_and_resize(
                     tf.expand_dims(inputs['gt_masks'], 1),
-                    fg_sampled_boxes,
+                    fastrcnn_head.fg_input_boxes(),
                     fg_inds_wrt_gt, 28,
                     pad_border=False)   # fg x 1x28x28
                 target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
-                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
+                mrcnn_loss = maskrcnn_loss(mask_logits, fastrcnn_head.fg_labels(), target_masks_for_fg)
             else:
                 mrcnn_loss = 0.0
@@ -379,8 +332,7 @@ class ResNetFPNModel(DetectionModel):
             add_moving_summary(total_cost, wd_cost)
             return total_cost
         else:
-            final_boxes, final_labels = self.fastrcnn_inference(
-                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
+            final_boxes, final_labels = self.fastrcnn_inference(image_shape2d, fastrcnn_head)
             if cfg.MODE_MASK:
                 # Cascade inference needs roi transform with refined boxes.
                 roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)