Commit 7ca798da authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] clip final boxes in the graph

parent d8da92d6
...@@ -12,7 +12,7 @@ from pycocotools.coco import COCO ...@@ -12,7 +12,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from coco import COCOMeta from coco import COCOMeta
from common import clip_boxes, CustomResize from common import CustomResize
import config import config
DetectionResult = namedtuple( DetectionResult = namedtuple(
...@@ -43,7 +43,6 @@ def detect_one_image(img, model_func): ...@@ -43,7 +43,6 @@ def detect_one_image(img, model_func):
scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2 scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
boxes, probs, labels = model_func(resized_img) boxes, probs, labels = model_func(resized_img)
boxes = boxes / scale boxes = boxes / scale
boxes = clip_boxes(boxes, img.shape[:2])
results = [DetectionResult(*args) for args in zip(labels, boxes, probs)] results = [DetectionResult(*args) for args in zip(labels, boxes, probs)]
return results return results
......
...@@ -14,6 +14,19 @@ from utils.box_ops import pairwise_iou ...@@ -14,6 +14,19 @@ from utils.box_ops import pairwise_iou
import config import config
@under_name_scope()
def clip_boxes(boxes, window, name=None):
"""
Args:
boxes: nx4, xyxy
window: [h, w]
"""
boxes = tf.maximum(boxes, 0.0)
m = tf.tile(tf.reverse(window, [0]), [2]) # (4,)
boxes = tf.minimum(boxes, tf.to_float(m), name=name)
return boxes
@layer_register(log_shape=True) @layer_register(log_shape=True)
def rpn_head(featuremap, channel, num_anchors): def rpn_head(featuremap, channel, num_anchors):
""" """
...@@ -171,13 +184,6 @@ def generate_rpn_proposals(boxes, scores, img_shape): ...@@ -171,13 +184,6 @@ def generate_rpn_proposals(boxes, scores, img_shape):
PRE_NMS_TOPK = config.TEST_PRE_NMS_TOPK PRE_NMS_TOPK = config.TEST_PRE_NMS_TOPK
POST_NMS_TOPK = config.TEST_POST_NMS_TOPK POST_NMS_TOPK = config.TEST_POST_NMS_TOPK
@under_name_scope()
def clip_boxes(boxes, window):
boxes = tf.maximum(boxes, 0.0)
m = tf.tile(tf.reverse(window, [0]), [2]) # (4,)
boxes = tf.minimum(boxes, tf.to_float(m))
return boxes
topk = tf.minimum(PRE_NMS_TOPK, tf.size(scores)) topk = tf.minimum(PRE_NMS_TOPK, tf.size(scores))
topk_scores, topk_indices = tf.nn.top_k(scores, k=topk, sorted=False) topk_scores, topk_indices = tf.nn.top_k(scores, k=topk, sorted=False)
topk_boxes = tf.gather(boxes, topk_indices) topk_boxes = tf.gather(boxes, topk_indices)
......
...@@ -25,10 +25,10 @@ from coco import COCODetection ...@@ -25,10 +25,10 @@ from coco import COCODetection
from basemodel import ( from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5) image_preprocess, pretrained_resnet_conv4, resnet_conv5)
from model import ( from model import (
decode_bbox_target, encode_bbox_target, clip_boxes, decode_bbox_target, encode_bbox_target,
rpn_head, rpn_losses, rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
roi_align, fastrcnn_head, fastrcnn_losses, fastrcnn_predictions) fastrcnn_head, fastrcnn_losses, fastrcnn_predictions)
from data import ( from data import (
get_train_dataflow, get_eval_dataflow, get_train_dataflow, get_eval_dataflow,
get_all_anchors) get_all_anchors)
...@@ -140,7 +140,7 @@ class Model(ModelDesc): ...@@ -140,7 +140,7 @@ class Model(ModelDesc):
decoded_boxes = decode_bbox_target( decoded_boxes = decode_bbox_target(
fastrcnn_box_logits / fastrcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors) tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_all_boxes') decoded_boxes = clip_boxes(decoded_boxes, tf.shape(image)[:2], name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category) # indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs) pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment