Commit dd1d8d21 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] add arguments for rpn_head

parent 39b010ab
......@@ -14,14 +14,19 @@ from utils.box_ops import pairwise_iou
import config
def rpn_head(featuremap):
def rpn_head(featuremap, channel, num_anchors):
"""
Returns:
label_logits: fHxfWxNA
box_logits: fHxfWxNAx4
"""
with tf.variable_scope('rpn'), \
argscope(Conv2D, data_format='NCHW',
W_init=tf.random_normal_initializer(stddev=0.01)):
hidden = Conv2D('conv0', featuremap, 1024, 3, nl=tf.nn.relu)
hidden = Conv2D('conv0', featuremap, channel, 3, nl=tf.nn.relu)
label_logits = Conv2D('class', hidden, config.NR_ANCHOR, 1)
box_logits = Conv2D('box', hidden, 4 * config.NR_ANCHOR, 1)
label_logits = Conv2D('class', hidden, num_anchors, 1)
box_logits = Conv2D('box', hidden, 4 * num_anchors, 1)
# 1, NA(*4), im/16, im/16 (NCHW)
label_logits = tf.transpose(label_logits, [0, 2, 3, 1]) # 1xfHxfWxNA
......@@ -29,7 +34,7 @@ def rpn_head(featuremap):
shp = tf.shape(box_logits) # 1x(NAx4)xfHxfW
box_logits = tf.transpose(box_logits, [0, 2, 3, 1]) # 1xfHxfWx(NAx4)
box_logits = tf.reshape(box_logits, tf.stack([shp[2], shp[3], config.NR_ANCHOR, 4])) # fHxfWxNAx4
box_logits = tf.reshape(box_logits, tf.stack([shp[2], shp[3], num_anchors, 4])) # fHxfWxNAx4
return label_logits, box_logits
......
......@@ -80,7 +80,7 @@ class Model(ModelDesc):
# resnet50
featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap)
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR)
rpn_label_loss, rpn_box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment