Commit a042f821 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] reorgnaize sample_fast_rcnn_targets

parent acfb57c2
......@@ -209,6 +209,26 @@ def generate_rpn_proposals(boxes, scores, img_shape):
return final_boxes, final_scores
@under_name_scope()
def proposal_metrics(iou):
"""
Args:
iou: nxm, #proposal x #gt
"""
# find best roi for each gt, for summary only
best_iou = tf.reduce_max(iou, axis=0)
mean_best_iou = tf.reduce_mean(best_iou, name='best_iou_per_gt')
summaries = [mean_best_iou]
with tf.device('/cpu:0'):
for th in [0.3, 0.5]:
recall = tf.truediv(
tf.count_nonzero(best_iou >= th),
tf.size(best_iou, out_type=tf.int64),
name='recall_iou{}'.format(th))
summaries.append(recall)
add_moving_summary(*summaries)
@under_name_scope()
def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
"""
......@@ -226,7 +246,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
def assign_class_to_roi(iou, gt_boxes, gt_labels):
"""
Args:
iou: nxm (nr_proposal x nr_gt)
iou: nxm (#proposal x #gt)
Returns:
fg_mask: n boolean, whether each roibox is fg
roi_labels: n int32, best label for each roi box
......@@ -242,55 +262,38 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
return fg_mask, best_gt_labels, best_gt_boxes
iou = pairwise_iou(boxes, gt_boxes) # nxm
proposal_metrics(iou)
with tf.name_scope('proposal_metrics'):
# find best roi for each gt, for summary only
best_iou = tf.reduce_max(iou, axis=0)
mean_best_iou = tf.reduce_mean(best_iou, name='best_iou_per_gt')
summaries = [mean_best_iou]
with tf.device('/cpu:0'):
for th in [0.3, 0.5]:
recall = tf.truediv(
tf.count_nonzero(best_iou >= th),
tf.size(best_iou, out_type=tf.int64),
name='recall_iou{}'.format(th))
summaries.append(recall)
add_moving_summary(*summaries)
# add ground truth as proposals as well
boxes = tf.concat([boxes, gt_boxes], axis=0) # (n+m) x 4
iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0) # (n+m) x m
# n, n, nx4
# n+m, n+m, (n+m)x4
fg_mask, roi_labels, best_gt_boxes = assign_class_to_roi(iou, gt_boxes, gt_labels)
# don't have to add gt for training, but add it anyway
fg_inds = tf.reshape(tf.where(fg_mask), [-1])
fg_inds = tf.concat([fg_inds, tf.cast(
tf.range(tf.size(gt_labels)) + tf.shape(boxes)[0],
tf.int64)], 0)
num_fg = tf.size(fg_inds)
fg_inds = tf.where(fg_mask)[:, 0]
num_fg = tf.minimum(int(
config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO),
num_fg, name='num_fg')
fg_inds = tf.slice(tf.random_shuffle(fg_inds), [0], [num_fg])
tf.size(fg_inds), name='num_fg')
fg_inds = tf.random_shuffle(fg_inds)[:num_fg]
bg_inds = tf.where(tf.logical_not(fg_mask))[:, 0]
num_bg = tf.size(bg_inds)
num_bg = tf.minimum(config.FASTRCNN_BATCH_PER_IM - num_fg, num_bg, name='num_bg')
bg_inds = tf.slice(tf.random_shuffle(bg_inds), [0], [num_bg])
num_bg = tf.minimum(
config.FASTRCNN_BATCH_PER_IM - num_fg,
tf.size(bg_inds), name='num_bg')
bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
add_moving_summary(num_fg, num_bg)
all_boxes = tf.concat([boxes, gt_boxes], axis=0)
all_matched_gt_boxes = tf.concat([best_gt_boxes, gt_boxes], axis=0)
all_labels = tf.concat([roi_labels, gt_labels], axis=0)
ind_in_all = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes
ret_boxes = tf.gather(all_boxes, ind_in_all, name='sampled_boxes')
ret_matched_gt_boxes = tf.gather(all_matched_gt_boxes, ind_in_all)
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # ind in all n+m boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_boxes')
ret_matched_gt_boxes = tf.gather(best_gt_boxes, all_indices)
ret_encoded_boxes = encode_bbox_target(ret_matched_gt_boxes, ret_boxes)
ret_encoded_boxes = ret_encoded_boxes * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
# bg boxes will not be trained on
ret_labels = tf.concat(
[tf.gather(all_labels, fg_inds),
[tf.gather(roi_labels, fg_inds),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
return ret_boxes, tf.stop_gradient(ret_encoded_boxes), tf.stop_gradient(ret_labels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment