Commit e1e2ee8d authored by Yuxin Wu

[FasterRCNN] rcnn loss takes fg boxes only

parent e37e91e8
...@@ -129,11 +129,11 @@ def decode_bbox_target(box_predictions, anchors): ...@@ -129,11 +129,11 @@ def decode_bbox_target(box_predictions, anchors):
def encode_bbox_target(boxes, anchors): def encode_bbox_target(boxes, anchors):
""" """
Args: Args:
boxes: fHxfWxNAx4, float32 boxes: (..., 4), float32
anchors: fHxfWxNAx4, float32 anchors: (..., 4), float32
Returns: Returns:
box_encoded: fHxfWxNAx4 box_encoded: (..., 4), float32 with the same shape.
""" """
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2)) anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1) anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
...@@ -249,11 +249,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -249,11 +249,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
# add ground truth as proposals as well # add ground truth as proposals as well
boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4 boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4
iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m
# #proposal=n+m from now on # #proposal=n+m from now on
best_iou_ind = tf.argmax(iou, axis=1) # #proposal, each in 1~m
best_gt_boxes = tf.gather(gt_boxes, best_iou_ind) # #proposalx4
best_gt_labels = tf.gather(gt_labels, best_iou_ind) # #proposal, each in 1~C
def sample_fg_bg(iou): def sample_fg_bg(iou):
fg_mask = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH fg_mask = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH
...@@ -274,18 +270,18 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -274,18 +270,18 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
return fg_inds, bg_inds return fg_inds, bg_inds
fg_inds, bg_inds = sample_fg_bg(iou) fg_inds, bg_inds = sample_fg_bg(iou)
# fg,bg indices w.r.t proposals
best_iou_ind = tf.argmax(iou, axis=1) # #proposal, each in 0~m-1
fg_inds_wrt_gt = tf.gather(best_iou_ind, fg_inds) # num_fg
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes all_indices = tf.concat([fg_inds, bg_inds], axis=0) # indices w.r.t all n+m proposal boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_boxes') ret_boxes = tf.gather(boxes, all_indices, name='sampled_proposal_boxes')
ret_matched_gt_boxes = tf.gather(best_gt_boxes, all_indices)
ret_encoded_boxes = encode_bbox_target(ret_matched_gt_boxes, ret_boxes)
ret_encoded_boxes = ret_encoded_boxes * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
# bg boxes will not be trained on
ret_labels = tf.concat( ret_labels = tf.concat(
[tf.gather(best_gt_labels, fg_inds), [tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels') tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
return ret_boxes, tf.stop_gradient(ret_encoded_boxes), tf.stop_gradient(ret_labels) return ret_boxes, tf.stop_gradient(ret_labels), fg_inds_wrt_gt
@under_name_scope() @under_name_scope()
...@@ -408,13 +404,13 @@ def fastrcnn_predict_boxes(labels, box_logits): ...@@ -408,13 +404,13 @@ def fastrcnn_predict_boxes(labels, box_logits):
@under_name_scope() @under_name_scope()
def fastrcnn_losses(labels, boxes, label_logits, box_logits): def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
""" """
Args: Args:
labels: n, labels: n,
boxes: nx4, encoded
label_logits: nxC label_logits: nxC
box_logits: nx(C-1)x4 fg_boxes: nfgx4, encoded
fg_box_logits: nfgx(C-1)x4
""" """
label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=label_logits) labels=labels, logits=label_logits)
...@@ -423,15 +419,19 @@ def fastrcnn_losses(labels, boxes, label_logits, box_logits): ...@@ -423,15 +419,19 @@ def fastrcnn_losses(labels, boxes, label_logits, box_logits):
correct = tf.to_float(tf.equal(prediction, labels)) # boolean/integer gather is unavailable on GPU correct = tf.to_float(tf.equal(prediction, labels)) # boolean/integer gather is unavailable on GPU
accuracy = tf.reduce_mean(correct, name='accuracy') accuracy = tf.reduce_mean(correct, name='accuracy')
# n x c-1 x 4 -> nfg x 4 fg_inds = tf.where(labels > 0)[:, 0]
fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, box_logits) fg_labels = tf.gather(labels, fg_inds)
fg_boxes = tf.gather(boxes, fg_ind) # nfgx4 num_fg = tf.size(fg_inds)
indices = tf.stack(
fg_label_pred = tf.argmax(tf.gather(label_logits, fg_ind), axis=1) [tf.range(num_fg),
num_zero = tf.reduce_sum(tf.cast(tf.equal(fg_label_pred, 0), tf.int32), name='num_zero') tf.to_int32(fg_labels) - 1], axis=1) # #fgx2
false_negative = tf.truediv(num_zero, tf.size(fg_ind), name='false_negative') fg_box_logits = tf.gather_nd(fg_box_logits, indices)
fg_correct = tf.gather(correct, fg_ind)
fg_accuracy = tf.reduce_mean(fg_correct, name='fg_accuracy') fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
num_zero = tf.reduce_sum(tf.to_int32(tf.equal(fg_label_pred, 0)), name='num_zero')
false_negative = tf.truediv(num_zero, num_fg, name='false_negative')
fg_accuracy = tf.reduce_mean(
tf.gather(correct, fg_inds), name='fg_accuracy')
box_loss = tf.losses.huber_loss( box_loss = tf.losses.huber_loss(
fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM) fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM)
......
...@@ -57,7 +57,7 @@ class Model(ModelDesc): ...@@ -57,7 +57,7 @@ class Model(ModelDesc):
InputDesc(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'), InputDesc(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'),
InputDesc(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'), InputDesc(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'),
InputDesc(tf.float32, (None, 4), 'gt_boxes'), InputDesc(tf.float32, (None, 4), 'gt_boxes'),
InputDesc(tf.int64, (None,), 'gt_labels'), InputDesc(tf.int64, (None,), 'gt_labels'), # all > 0
] ]
def _preprocess(self, image): def _preprocess(self, image):
...@@ -100,7 +100,7 @@ class Model(ModelDesc): ...@@ -100,7 +100,7 @@ class Model(ModelDesc):
if is_training: if is_training:
# sample proposal boxes in training # sample proposal boxes in training
rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets( rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels) proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE) boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
else: else:
...@@ -112,8 +112,17 @@ class Model(ModelDesc): ...@@ -112,8 +112,17 @@ class Model(ModelDesc):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS) fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
if is_training: if is_training:
fg_inds_wrt_sample = tf.where(rcnn_labels > 0)[:, 0] # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
encoded_boxes = encode_bbox_target(
matched_gt_boxes,
fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses( fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits, fastrcnn_box_logits) rcnn_labels, fastrcnn_label_logits,
encoded_boxes,
tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))
wd_cost = regularize_cost( wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn)/.*W', '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment