Commit 438aef79 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] handle empty forground in frcnn head.

parent 6041a1a4
......@@ -144,6 +144,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds = tf.where(labels > 0)[:, 0]
fg_labels = tf.gather(labels, fg_inds)
num_fg = tf.size(fg_inds, out_type=tf.int64)
empty_fg = tf.equal(num_fg, 0)
if int(fg_box_logits.shape[1]) > 1:
indices = tf.stack(
[tf.range(num_fg), fg_labels], axis=1) # #fgx2
......@@ -157,16 +158,18 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
accuracy = tf.reduce_mean(correct, name='accuracy')
fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
num_zero = tf.reduce_sum(tf.to_int64(tf.equal(fg_label_pred, 0)), name='num_zero')
false_negative = tf.truediv(num_zero, num_fg, name='false_negative')
fg_accuracy = tf.reduce_mean(
tf.gather(correct, fg_inds), name='fg_accuracy')
false_negative = tf.where(
empty_fg, 0., tf.truediv(num_zero, num_fg), name='false_negative')
fg_accuracy = tf.where(
empty_fg, 0., tf.reduce_mean(tf.gather(correct, fg_inds)), name='fg_accuracy')
box_loss = tf.losses.huber_loss(
fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM)
box_loss = tf.truediv(
box_loss, tf.to_float(tf.shape(labels)[0]), name='box_loss')
add_moving_summary(label_loss, box_loss, accuracy, fg_accuracy, false_negative)
add_moving_summary(label_loss, box_loss, accuracy,
fg_accuracy, false_negative, tf.to_float(num_fg, name='num_fg_label'))
return label_loss, box_loss
......
......@@ -141,12 +141,13 @@ class ResNetC4Model(DetectionModel):
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training:
all_losses = []
# rpn loss
rpn_label_loss, rpn_box_loss = rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits)
all_losses.extend(rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits))
# fastrcnn loss
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses()
all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK:
# maskrcnn loss
......@@ -161,18 +162,13 @@ class ResNetC4Model(DetectionModel):
proposals.fg_inds_wrt_gt, 14,
pad_border=False) # nfg x 1x14x14
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)
else:
mrcnn_loss = 0.0
all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n([
rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost')
total_cost = tf.add_n(all_losses, 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
......@@ -272,11 +268,11 @@ class ResNetFPNModel(DetectionModel):
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training:
# rpn loss:
rpn_label_loss, rpn_box_loss = multilevel_rpn_losses(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits)
all_losses = []
all_losses.extend(multilevel_rpn_losses(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits))
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses()
all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK:
# maskrcnn loss
......@@ -293,17 +289,13 @@ class ResNetFPNModel(DetectionModel):
proposals.fg_inds_wrt_gt, 28,
pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg)
else:
mrcnn_loss = 0.0
all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n([rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost')
total_cost = tf.add_n(all_losses, 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
......
......@@ -141,7 +141,7 @@ class CollectionGuard(object):
size_change.append((self._key_name(k), len(old_v), len(v)))
if newly_created:
logger.info(
"New collections created in {}: {}".format(
"New collections created in tower {}: {}".format(
self._name, ', '.join(newly_created)))
if size_change:
logger.info(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment