Commit f17d16da authored by Yuxin Wu's avatar Yuxin Wu

[Cascade R-CNN] only train with nonempty boxes

parent b6318616
......@@ -3,7 +3,7 @@ import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
from config import config as cfg
from utils.box_ops import pairwise_iou
from utils.box_ops import pairwise_iou, area as tf_area
from .model_box import clip_boxes
from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
......@@ -27,8 +27,8 @@ class CascadeRCNNHead(object):
self.num_cascade_stages = len(cfg.CASCADE.IOUS)
self.is_training = get_current_tower_context().is_training
if self.is_training:
self.training = get_current_tower_context().is_training
if self.training:
@tf.custom_gradient
def scale_gradient(x):
return x, lambda dy: dy * (1.0 / self.num_cascade_stages)
......@@ -72,6 +72,8 @@ class CascadeRCNNHead(object):
refined_boxes = head.decoded_output_boxes_class_agnostic()
refined_boxes = clip_boxes(refined_boxes, self.image_shape2d)
if self.training:
refined_boxes = tf.boolean_mask(refined_boxes, tf_area(refined_boxes) > 0)
return head, tf.stop_gradient(refined_boxes, name='output_boxes')
def match_box_with_gt(self, boxes, iou_threshold):
......@@ -81,7 +83,7 @@ class CascadeRCNNHead(object):
Returns:
BoxProposals
"""
if self.is_training:
if self.training:
with tf.name_scope('match_box_with_gt_{}'.format(iou_threshold)):
iou = pairwise_iou(boxes, self.gt_boxes) # NxM
max_iou_per_box = tf.reduce_max(iou, axis=1) # N
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment