Commit 5140610d authored by Yuxin Wu

[FasterRCNN] reorganize sample_fastrcnn_target

parent bda9e14e
......@@ -242,25 +242,6 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
target_boxes: tx4 encoded box, the regression target
labels: t labels
"""
@under_name_scope()
def assign_class_to_roi(iou, gt_boxes, gt_labels):
"""
Args:
iou: nxm (#proposal x #gt)
Returns:
fg_mask: n boolean, whether each roibox is fg
roi_labels: n int32, best label for each roi box
best_gt_boxes: nx4
"""
# find best gt box for each roi box
best_iou_ind = tf.argmax(iou, axis=1) # n, each in 0~m-1
best_iou = tf.reduce_max(iou, axis=1) # n,
best_gt_boxes = tf.gather(gt_boxes, best_iou_ind) # nx4
best_gt_labels = tf.gather(gt_labels, best_iou_ind) # n, each in 1~C
fg_mask = best_iou >= config.FASTRCNN_FG_THRESH
return fg_mask, best_gt_labels, best_gt_boxes
iou = pairwise_iou(boxes, gt_boxes) # nxm
proposal_metrics(iou)
......@@ -268,22 +249,30 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4
iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m
# n+m, n+m, (n+m)x4
fg_mask, roi_labels, best_gt_boxes = assign_class_to_roi(iou, gt_boxes, gt_labels)
# #proposal=n+m from now on
best_iou_ind = tf.argmax(iou, axis=1) # #proposal, each in 0~m-1
best_gt_boxes = tf.gather(gt_boxes, best_iou_ind) # #proposalx4
best_gt_labels = tf.gather(gt_labels, best_iou_ind) # #proposal, each in 1~C
def sample_fg_bg(iou):
fg_mask = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH
fg_inds = tf.where(fg_mask)[:, 0]
num_fg = tf.minimum(int(
config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO),
tf.size(fg_inds), name='num_fg')
fg_inds = tf.random_shuffle(fg_inds)[:num_fg]
fg_inds = tf.where(fg_mask)[:, 0]
num_fg = tf.minimum(int(
config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO),
tf.size(fg_inds), name='num_fg')
fg_inds = tf.random_shuffle(fg_inds)[:num_fg]
bg_inds = tf.where(tf.logical_not(fg_mask))[:, 0]
num_bg = tf.minimum(
config.FASTRCNN_BATCH_PER_IM - num_fg,
tf.size(bg_inds), name='num_bg')
bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
bg_inds = tf.where(tf.logical_not(fg_mask))[:, 0]
num_bg = tf.minimum(
config.FASTRCNN_BATCH_PER_IM - num_fg,
tf.size(bg_inds), name='num_bg')
bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
add_moving_summary(num_fg, num_bg)
return fg_inds, bg_inds
add_moving_summary(num_fg, num_bg)
fg_inds, bg_inds = sample_fg_bg(iou)
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_boxes')
......@@ -293,7 +282,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
# bg boxes will not be trained on
ret_labels = tf.concat(
[tf.gather(roi_labels, fg_inds),
[tf.gather(best_gt_labels, fg_inds),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
return ret_boxes, tf.stop_gradient(ret_encoded_boxes), tf.stop_gradient(ret_labels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment