Commit ce84e1c9 authored by Yuxin Wu's avatar Yuxin Wu

fix loss scale

parent 1fc18a6e
...@@ -141,12 +141,12 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes): ...@@ -141,12 +141,12 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES] overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES]
anchor_labels[overlap_with_crowd] = -1 anchor_labels[overlap_with_crowd] = -1
# Filter fg labels: ignore some fg if fg is too many # Subsample fg labels: ignore some fg if fg is too many
target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO) target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO)
fg_inds = filter_box_label(anchor_labels, 1, target_num_fg) fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
# Note that fg could be fewer than the target ratio # Note that fg could be fewer than the target ratio
# filter bg labels. num_bg is not allowed to be too many # Subsample bg labels. num_bg is not allowed to be too many
old_num_bg = np.sum(anchor_labels == 0) old_num_bg = np.sum(anchor_labels == 0)
if old_num_bg == 0 or len(fg_inds) == 0: if old_num_bg == 0 or len(fg_inds) == 0:
# No valid bg/fg in this image, skip. # No valid bg/fg in this image, skip.
......
...@@ -99,7 +99,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits): ...@@ -99,7 +99,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
placeholder = 1. placeholder = 1.
label_loss = tf.nn.sigmoid_cross_entropy_with_logits( label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits) labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
label_loss = tf.reduce_mean(label_loss) label_loss = label_loss * (1. / config.RPN_BATCH_PER_IM)
label_loss = tf.where(tf.equal(nr_valid, 0), placeholder, label_loss, name='label_loss') label_loss = tf.where(tf.equal(nr_valid, 0), placeholder, label_loss, name='label_loss')
pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask) pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask)
...@@ -108,9 +108,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits): ...@@ -108,9 +108,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
box_loss = tf.losses.huber_loss( box_loss = tf.losses.huber_loss(
pos_anchor_boxes, pos_box_logits, delta=delta, pos_anchor_boxes, pos_box_logits, delta=delta,
reduction=tf.losses.Reduction.SUM) / delta reduction=tf.losses.Reduction.SUM) / delta
box_loss = tf.div( box_loss = box_loss * (1. / config.RPN_BATCH_PER_IM)
box_loss,
tf.cast(nr_valid, tf.float32))
box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss') box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss')
add_moving_summary(label_loss, box_loss, nr_valid, nr_pos) add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment