Commit 5140610d authored by Yuxin Wu

[FasterRCNN] reorganize sample_fastrcnn_target

parent bda9e14e
...@@ -242,25 +242,6 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -242,25 +242,6 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
target_boxes: tx4 encoded box, the regression target target_boxes: tx4 encoded box, the regression target
labels: t labels labels: t labels
""" """
@under_name_scope()
def assign_class_to_roi(iou, gt_boxes, gt_labels):
    """
    Match every roi box to the ground-truth box it overlaps the most.

    Args:
        iou: nxm (#proposal x #gt) pairwise IoU
        gt_boxes: ground-truth boxes, gathered along axis 0 by gt index
            (presumably mx4 -- confirm against the caller)
        gt_labels: ground-truth labels, gathered along axis 0 by gt index
            (presumably m -- confirm against the caller)
    Returns:
        fg_mask: n boolean, True where the best IoU of a roi box
            reaches config.FASTRCNN_FG_THRESH
        roi_labels: n, label of the best-matching gt for each roi box
            (same dtype as gt_labels)
        best_gt_boxes: nx4, the best-matching gt box for each roi box
    """
    # Keep the op creation order (argmax, max, gather boxes, gather labels,
    # compare) so that auto-generated op names inside the scope do not shift.
    matched_gt = tf.argmax(iou, axis=1)  # n, index of the best gt per roi box
    max_overlap = tf.reduce_max(iou, axis=1)  # n, IoU with that gt
    boxes_of_match = tf.gather(gt_boxes, matched_gt)  # nx4
    labels_of_match = tf.gather(gt_labels, matched_gt)  # n
    is_fg = max_overlap >= config.FASTRCNN_FG_THRESH
    return is_fg, labels_of_match, boxes_of_match
iou = pairwise_iou(boxes, gt_boxes) # nxm iou = pairwise_iou(boxes, gt_boxes) # nxm
proposal_metrics(iou) proposal_metrics(iou)
...@@ -268,8 +249,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -268,8 +249,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4 boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4
iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m
# n+m, n+m, (n+m)x4 # #proposal=n+m from now on
fg_mask, roi_labels, best_gt_boxes = assign_class_to_roi(iou, gt_boxes, gt_labels) best_iou_ind = tf.argmax(iou, axis=1) # #proposal, each in 1~m
best_gt_boxes = tf.gather(gt_boxes, best_iou_ind) # #proposalx4
best_gt_labels = tf.gather(gt_labels, best_iou_ind) # #proposal, each in 1~C
def sample_fg_bg(iou):
fg_mask = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH
fg_inds = tf.where(fg_mask)[:, 0] fg_inds = tf.where(fg_mask)[:, 0]
num_fg = tf.minimum(int( num_fg = tf.minimum(int(
...@@ -284,6 +270,9 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -284,6 +270,9 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
bg_inds = tf.random_shuffle(bg_inds)[:num_bg] bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
add_moving_summary(num_fg, num_bg) add_moving_summary(num_fg, num_bg)
return fg_inds, bg_inds
fg_inds, bg_inds = sample_fg_bg(iou)
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes all_indices = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_boxes') ret_boxes = tf.gather(boxes, all_indices, name='sampled_boxes')
...@@ -293,7 +282,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -293,7 +282,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
# bg boxes will not be trained on # bg boxes will not be trained on
ret_labels = tf.concat( ret_labels = tf.concat(
[tf.gather(roi_labels, fg_inds), [tf.gather(best_gt_labels, fg_inds),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels') tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
return ret_boxes, tf.stop_gradient(ret_encoded_boxes), tf.stop_gradient(ret_labels) return ret_boxes, tf.stop_gradient(ret_encoded_boxes), tf.stop_gradient(ret_labels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment