Commit 438aef79 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] handle empty forground in frcnn head.

parent 6041a1a4
...@@ -144,6 +144,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -144,6 +144,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds = tf.where(labels > 0)[:, 0] fg_inds = tf.where(labels > 0)[:, 0]
fg_labels = tf.gather(labels, fg_inds) fg_labels = tf.gather(labels, fg_inds)
num_fg = tf.size(fg_inds, out_type=tf.int64) num_fg = tf.size(fg_inds, out_type=tf.int64)
empty_fg = tf.equal(num_fg, 0)
if int(fg_box_logits.shape[1]) > 1: if int(fg_box_logits.shape[1]) > 1:
indices = tf.stack( indices = tf.stack(
[tf.range(num_fg), fg_labels], axis=1) # #fgx2 [tf.range(num_fg), fg_labels], axis=1) # #fgx2
...@@ -157,16 +158,18 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -157,16 +158,18 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
accuracy = tf.reduce_mean(correct, name='accuracy') accuracy = tf.reduce_mean(correct, name='accuracy')
fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1) fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
num_zero = tf.reduce_sum(tf.to_int64(tf.equal(fg_label_pred, 0)), name='num_zero') num_zero = tf.reduce_sum(tf.to_int64(tf.equal(fg_label_pred, 0)), name='num_zero')
false_negative = tf.truediv(num_zero, num_fg, name='false_negative') false_negative = tf.where(
fg_accuracy = tf.reduce_mean( empty_fg, 0., tf.truediv(num_zero, num_fg), name='false_negative')
tf.gather(correct, fg_inds), name='fg_accuracy') fg_accuracy = tf.where(
empty_fg, 0., tf.reduce_mean(tf.gather(correct, fg_inds)), name='fg_accuracy')
box_loss = tf.losses.huber_loss( box_loss = tf.losses.huber_loss(
fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM) fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM)
box_loss = tf.truediv( box_loss = tf.truediv(
box_loss, tf.to_float(tf.shape(labels)[0]), name='box_loss') box_loss, tf.to_float(tf.shape(labels)[0]), name='box_loss')
add_moving_summary(label_loss, box_loss, accuracy, fg_accuracy, false_negative) add_moving_summary(label_loss, box_loss, accuracy,
fg_accuracy, false_negative, tf.to_float(num_fg, name='num_fg_label'))
return label_loss, box_loss return label_loss, box_loss
......
...@@ -141,12 +141,13 @@ class ResNetC4Model(DetectionModel): ...@@ -141,12 +141,13 @@ class ResNetC4Model(DetectionModel):
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)) tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training: if is_training:
all_losses = []
# rpn loss # rpn loss
rpn_label_loss, rpn_box_loss = rpn_losses( all_losses.extend(rpn_losses(
anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits) anchors.gt_labels, anchors.encoded_gt_boxes(), rpn_label_logits, rpn_box_logits))
# fastrcnn loss # fastrcnn loss
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses() all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK: if cfg.MODE_MASK:
# maskrcnn loss # maskrcnn loss
...@@ -161,18 +162,13 @@ class ResNetC4Model(DetectionModel): ...@@ -161,18 +162,13 @@ class ResNetC4Model(DetectionModel):
proposals.fg_inds_wrt_gt, 14, proposals.fg_inds_wrt_gt, 14,
pad_border=False) # nfg x 1x14x14 pad_border=False) # nfg x 1x14x14
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets') target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg) all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
else:
mrcnn_loss = 0.0
wd_cost = regularize_cost( wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost') '.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n([ total_cost = tf.add_n(all_losses, 'total_cost')
rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost) add_moving_summary(total_cost, wd_cost)
return total_cost return total_cost
else: else:
...@@ -272,11 +268,11 @@ class ResNetFPNModel(DetectionModel): ...@@ -272,11 +268,11 @@ class ResNetFPNModel(DetectionModel):
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)) tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
if is_training: if is_training:
# rpn loss: all_losses = []
rpn_label_loss, rpn_box_loss = multilevel_rpn_losses( all_losses.extend(multilevel_rpn_losses(
multilevel_anchors, multilevel_label_logits, multilevel_box_logits) multilevel_anchors, multilevel_label_logits, multilevel_box_logits))
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_head.losses() all_losses.extend(fastrcnn_head.losses())
if cfg.MODE_MASK: if cfg.MODE_MASK:
# maskrcnn loss # maskrcnn loss
...@@ -293,17 +289,13 @@ class ResNetFPNModel(DetectionModel): ...@@ -293,17 +289,13 @@ class ResNetFPNModel(DetectionModel):
proposals.fg_inds_wrt_gt, 28, proposals.fg_inds_wrt_gt, 28,
pad_border=False) # fg x 1x28x28 pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets') target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg) all_losses.append(maskrcnn_loss(mask_logits, proposals.fg_labels(), target_masks_for_fg))
else:
mrcnn_loss = 0.0
wd_cost = regularize_cost( wd_cost = regularize_cost(
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost') '.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
all_losses.append(wd_cost)
total_cost = tf.add_n([rpn_label_loss, rpn_box_loss, total_cost = tf.add_n(all_losses, 'total_cost')
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost) add_moving_summary(total_cost, wd_cost)
return total_cost return total_cost
else: else:
......
...@@ -141,7 +141,7 @@ class CollectionGuard(object): ...@@ -141,7 +141,7 @@ class CollectionGuard(object):
size_change.append((self._key_name(k), len(old_v), len(v))) size_change.append((self._key_name(k), len(old_v), len(v)))
if newly_created: if newly_created:
logger.info( logger.info(
"New collections created in {}: {}".format( "New collections created in tower {}: {}".format(
self._name, ', '.join(newly_created))) self._name, ', '.join(newly_created)))
if size_change: if size_change:
logger.info( logger.info(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment