Commit f17d16da authored by Yuxin Wu's avatar Yuxin Wu

[Cascade R-CNN] only train with nonempty boxes

parent b6318616
...@@ -3,7 +3,7 @@ import tensorflow as tf ...@@ -3,7 +3,7 @@ import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context from tensorpack.tfutils import get_current_tower_context
from config import config as cfg from config import config as cfg
from utils.box_ops import pairwise_iou from utils.box_ops import pairwise_iou, area as tf_area
from .model_box import clip_boxes from .model_box import clip_boxes
from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
...@@ -27,8 +27,8 @@ class CascadeRCNNHead(object): ...@@ -27,8 +27,8 @@ class CascadeRCNNHead(object):
self.num_cascade_stages = len(cfg.CASCADE.IOUS) self.num_cascade_stages = len(cfg.CASCADE.IOUS)
self.is_training = get_current_tower_context().is_training self.training = get_current_tower_context().is_training
if self.is_training: if self.training:
@tf.custom_gradient @tf.custom_gradient
def scale_gradient(x): def scale_gradient(x):
return x, lambda dy: dy * (1.0 / self.num_cascade_stages) return x, lambda dy: dy * (1.0 / self.num_cascade_stages)
...@@ -72,6 +72,8 @@ class CascadeRCNNHead(object): ...@@ -72,6 +72,8 @@ class CascadeRCNNHead(object):
refined_boxes = head.decoded_output_boxes_class_agnostic() refined_boxes = head.decoded_output_boxes_class_agnostic()
refined_boxes = clip_boxes(refined_boxes, self.image_shape2d) refined_boxes = clip_boxes(refined_boxes, self.image_shape2d)
if self.training:
refined_boxes = tf.boolean_mask(refined_boxes, tf_area(refined_boxes) > 0)
return head, tf.stop_gradient(refined_boxes, name='output_boxes') return head, tf.stop_gradient(refined_boxes, name='output_boxes')
def match_box_with_gt(self, boxes, iou_threshold): def match_box_with_gt(self, boxes, iou_threshold):
...@@ -81,7 +83,7 @@ class CascadeRCNNHead(object): ...@@ -81,7 +83,7 @@ class CascadeRCNNHead(object):
Returns: Returns:
BoxProposals BoxProposals
""" """
if self.is_training: if self.training:
with tf.name_scope('match_box_with_gt_{}'.format(iou_threshold)): with tf.name_scope('match_box_with_gt_{}'.format(iou_threshold)):
iou = pairwise_iou(boxes, self.gt_boxes) # NxM iou = pairwise_iou(boxes, self.gt_boxes) # NxM
max_iou_per_box = tf.reduce_max(iou, axis=1) # N max_iou_per_box = tf.reduce_max(iou, axis=1) # N
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment