Commit 33304beb authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] use DECODE_BBOX_CLIP constant instead of an argument

parent b3f6e31d
......@@ -25,6 +25,8 @@ ANCHOR_RATIOS = (0.5, 1., 2.)
NUM_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
POSITIVE_ANCHOR_THRES = 0.7
NEGATIVE_ANCHOR_THRES = 0.3
# just to avoid too large numbers.
BBOX_DECODE_CLIP = np.log(MAX_SIZE / 16.0)
# rpn training -------------------------
# keep fg ratio in a batch in this range
......
......@@ -96,12 +96,11 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
@under_name_scope()
def decode_bbox_target(box_predictions, anchors, stride):
def decode_bbox_target(box_predictions, anchors):
"""
Args:
box_predictions: fHxfWxNAx4, logits
anchors: fHxfWxNAx4, floatbox
stride (int): the stride of the anchors
Returns:
box_decoded: (fHxfWxNA)x4, float32
......@@ -116,7 +115,7 @@ def decode_bbox_target(box_predictions, anchors, stride):
xaya = tf.to_float(anchors_x2y2 + anchors_x1y1) * 0.5
wbhb = tf.exp(tf.minimum(
box_pred_twth, np.log(config.MAX_SIZE * 1.0 / stride))) * waha
box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
xbyb = box_pred_txty * waha + xaya
x1y1 = xbyb - wbhb * 0.5
x2y2 = xbyb + wbhb * 0.5
......
......@@ -92,7 +92,7 @@ class Model(ModelDesc):
rpn_label_loss, rpn_box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors, config.ANCHOR_STRIDE) # (fHxfWxNA)x4, floatbox
decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors) # (fHxfWxNA)x4, floatbox
proposal_boxes, proposal_scores = generate_rpn_proposals(
decoded_boxes,
tf.reshape(rpn_label_logits, [-1]),
......@@ -131,7 +131,7 @@ class Model(ModelDesc):
fg_boxes = tf.gather(proposal_boxes, fg_ind)
fg_box_logits = fg_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
decoded_boxes = decode_bbox_target(fg_box_logits, fg_boxes, config.ANCHOR_STRIDE) # Nfx4, floatbox
decoded_boxes = decode_bbox_target(fg_box_logits, fg_boxes) # Nfx4, floatbox
decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_fg_boxes')
def _get_optimizer(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment