Commit e37e91e8 authored by Yuxin Wu

[FastRCNN] decode_bbox_target preserve shape

parent df869711
@@ -41,7 +41,7 @@ TRAIN_POST_NMS_TOPK = 2000
 CROWD_OVERLAP_THRES = 0.7
 # fastrcnn training ---------------------
-FASTRCNN_BATCH_PER_IM = 64
+FASTRCNN_BATCH_PER_IM = 256
 FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
 FASTRCNN_FG_THRESH = 0.5
 # keep fg ratio in a batch in this range
...
@@ -53,7 +53,6 @@ def nms_fastrcnn_results(boxes, probs):
     C = probs.shape[1]
     boxes = boxes.copy()
-    boxes_per_class = {}
     nms_func = get_tf_nms(config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
     ret = []
     for klass in range(1, C):
...
@@ -100,15 +100,16 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
 def decode_bbox_target(box_predictions, anchors):
     """
     Args:
-        box_predictions: fHxfWxNAx4, logits
-        anchors: fHxfWxNAx4, floatbox
+        box_predictions: (..., 4), logits
+        anchors: (..., 4), floatbox. Must have the same shape
     Returns:
-        box_decoded: (fHxfWxNA)x4, float32
+        box_decoded: (..., 4), float32. With the same shape.
     """
+    orig_shape = tf.shape(anchors)
     box_pred_txtytwth = tf.reshape(box_predictions, (-1, 2, 2))
     box_pred_txty, box_pred_twth = tf.split(box_pred_txtytwth, 2, axis=1)
-    # each is (fHxfWxNA)x1x2
+    # each is (...)x1x2
     anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
     anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
@@ -119,9 +120,9 @@ def decode_bbox_target(box_predictions, anchors):
         box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
     xbyb = box_pred_txty * waha + xaya
     x1y1 = xbyb - wbhb * 0.5
-    x2y2 = xbyb + wbhb * 0.5
-    out = tf.squeeze(tf.concat([x1y1, x2y2], axis=2), axis=1, name='output')
-    return out
+    x2y2 = xbyb + wbhb * 0.5    # (...)x1x2
+    out = tf.concat([x1y1, x2y2], axis=1)
+    return tf.reshape(out, orig_shape)
 
 @under_name_scope()
...
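For reference, the decode itself is unchanged: each box is the anchor shifted by (tx, ty) in units of the anchor size and scaled by exp(tw), exp(th); the commit only records tf.shape(anchors) up front and reshapes the result back, so the output keeps the input layout instead of always coming out flattened. Below is a minimal NumPy sketch of that math, not part of the commit; the helper name, the example sizes, and the clip default are illustrative stand-ins (clip plays the role of config.BBOX_DECODE_CLIP).

import numpy as np

def decode_bbox_target_np(box_predictions, anchors, clip=np.log(1000.0 / 16)):
    """NumPy sketch of the shape-preserving decode: inputs and output are all (..., 4)."""
    orig_shape = anchors.shape
    pred = box_predictions.reshape(-1, 2, 2)        # (N, 2, 2): rows are (tx, ty) and (tw, th)
    txty, twth = pred[:, 0:1, :], pred[:, 1:2, :]   # each (N, 1, 2)
    anc = anchors.reshape(-1, 2, 2)                 # (N, 2, 2): rows are (x1, y1) and (x2, y2)
    x1y1a, x2y2a = anc[:, 0:1, :], anc[:, 1:2, :]
    waha = x2y2a - x1y1a                            # anchor width/height
    xaya = (x2y2a + x1y1a) * 0.5                    # anchor center
    wbhb = np.exp(np.minimum(twth, clip)) * waha    # decoded width/height
    xbyb = txty * waha + xaya                       # decoded center
    out = np.concatenate([xbyb - wbhb * 0.5, xbyb + wbhb * 0.5], axis=1)  # (N, 2, 2)
    return out.reshape(orig_shape)                  # reshape back, as the new TF code does

# With zero deltas the decoded boxes equal the anchors, and the fHxfWxNAx4 layout is kept.
anchors = np.random.rand(7, 7, 15, 4).astype('float32')
decoded = decode_bbox_target_np(np.zeros_like(anchors), anchors)
assert decoded.shape == anchors.shape
assert np.allclose(decoded, anchors, atol=1e-5)

Because the TF version now returns fHxfWxNAx4 for the RPN branch, the caller in the training code flattens it with tf.reshape(decoded_boxes, [-1, 4]) before generate_rpn_proposals, as the next hunk shows.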
@@ -39,7 +39,7 @@ from viz import (
     draw_predictions, draw_final_outputs)
 from common import clip_boxes, CustomResize, print_config
 from eval import (
-    eval_on_dataflow, detect_one_image, print_evaluation_scores, get_tf_nms,
+    eval_on_dataflow, detect_one_image, print_evaluation_scores,
     nms_fastrcnn_results)
 import config
@@ -92,9 +92,9 @@ class Model(ModelDesc):
         rpn_label_loss, rpn_box_loss = rpn_losses(
             anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
-        decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # (fHxfWxNA)x4, floatbox
+        decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
         proposal_boxes, proposal_scores = generate_rpn_proposals(
-            decoded_boxes,
+            tf.reshape(decoded_boxes, [-1, 4]),
             tf.reshape(rpn_label_logits, [-1]),
             tf.shape(image)[2:])
@@ -127,14 +127,14 @@ class Model(ModelDesc):
             for k in self.cost, wd_cost:
                 add_moving_summary(k)
         else:
-            label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs')  # NP,
+            label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
             labels = tf.argmax(fastrcnn_label_logits, axis=1)
             fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, fastrcnn_box_logits)
             fg_label_probs = tf.gather(label_probs, fg_ind, name='fastrcnn_fg_probs')
             fg_boxes = tf.gather(proposal_boxes, fg_ind)
 
             fg_box_logits = fg_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
-            decoded_boxes = decode_bbox_target(fg_box_logits, fg_boxes)  # Nfx4, floatbox
+            decoded_boxes = decode_bbox_target(fg_box_logits, fg_boxes)  # #fgx4, floatbox
             decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_fg_boxes')
 
     def _get_optimizer(self):
...
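One detail worth noting in the Fast R-CNN branch above: the head's raw box logits are divided by config.FASTRCNN_BBOX_REG_WEIGHTS (the [10, 10, 5, 5] array from the config hunk) before being passed to decode_bbox_target. A tiny sketch of that scaling, not part of the commit, with made-up logits:

import numpy as np

# Scaling applied to the Fast R-CNN box head outputs before decoding;
# the weights mirror config.FASTRCNN_BBOX_REG_WEIGHTS, the logits are hypothetical.
FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')

fg_box_logits = np.array([[1.0, -2.0, 0.5, 0.25]], dtype='float32')  # #fg x 4, hypothetical
deltas = fg_box_logits / FASTRCNN_BBOX_REG_WEIGHTS                   # (tx, ty, tw, th) fed to decode_bbox_target
print(deltas)  # -> approximately [[0.1, -0.2, 0.1, 0.05]]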