Commit 7ca798da authored by Yuxin Wu

[FasterRCNN] clip final boxes in the graph

parent d8da92d6
......@@ -12,7 +12,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from coco import COCOMeta
from common import clip_boxes, CustomResize
from common import CustomResize
import config
DetectionResult = namedtuple(
......@@ -43,7 +43,6 @@ def detect_one_image(img, model_func):
scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
boxes, probs, labels = model_func(resized_img)
boxes = boxes / scale
boxes = clip_boxes(boxes, img.shape[:2])
results = [DetectionResult(*args) for args in zip(labels, boxes, probs)]
return results
......
......@@ -14,6 +14,19 @@ from utils.box_ops import pairwise_iou
import config
@under_name_scope()
def clip_boxes(boxes, window, name=None):
    """
    Clamp every box coordinate so that it lies inside the given window.

    Args:
        boxes: nx4, xyxy
        window: [h, w]
        name: optional name assigned to the final (output) op.

    Returns:
        Tensor shaped like `boxes`, with each value limited to the range
        [0, w] for x coordinates and [0, h] for y coordinates.
    """
    # Lower bound first: no coordinate may be negative.
    non_negative = tf.maximum(boxes, 0.0)

    # [h, w] -> [w, h] -> [w, h, w, h], so the bound lines up with xyxy columns.
    width_height = tf.reverse(window, [0])
    upper_bound = tf.tile(width_height, [2])  # (4,)

    # The window is cast to float before being compared against the boxes.
    return tf.minimum(non_negative, tf.to_float(upper_bound), name=name)
@layer_register(log_shape=True)
def rpn_head(featuremap, channel, num_anchors):
"""
......@@ -171,13 +184,6 @@ def generate_rpn_proposals(boxes, scores, img_shape):
PRE_NMS_TOPK = config.TEST_PRE_NMS_TOPK
POST_NMS_TOPK = config.TEST_POST_NMS_TOPK
@under_name_scope()
def clip_boxes(boxes, window):
    """
    Clamp box coordinates to the range allowed by the window.

    Args:
        boxes: box coordinates; the upper bound built below has four entries,
            so four values per box are expected (presumably nx4 in xyxy order).
        window: two-element tensor that is reversed before use (presumably [h, w]).

    Returns:
        The boxes with values limited to [0, bound], where bound is the
        reversed window repeated twice.
    """
    lower_clipped = tf.maximum(boxes, 0.0)
    # Reverse the window and repeat it twice to get one bound per coordinate.
    limits = tf.reverse(window, [0])
    limits = tf.tile(limits, [2])  # (4,)
    return tf.minimum(lower_clipped, tf.to_float(limits))
topk = tf.minimum(PRE_NMS_TOPK, tf.size(scores))
topk_scores, topk_indices = tf.nn.top_k(scores, k=topk, sorted=False)
topk_boxes = tf.gather(boxes, topk_indices)
......
......@@ -25,10 +25,10 @@ from coco import COCODetection
from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5)
from model import (
decode_bbox_target, encode_bbox_target,
clip_boxes, decode_bbox_target, encode_bbox_target,
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets,
roi_align, fastrcnn_head, fastrcnn_losses, fastrcnn_predictions)
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors)
......@@ -140,7 +140,7 @@ class Model(ModelDesc):
decoded_boxes = decode_bbox_target(
fastrcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_all_boxes')
decoded_boxes = clip_boxes(decoded_boxes, tf.shape(image)[:2], name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment