Commit 26b13269 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] fix image shape bug..

parent 8543db08
...@@ -75,17 +75,17 @@ class Model(ModelDesc): ...@@ -75,17 +75,17 @@ class Model(ModelDesc):
all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32) all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32)
fm_anchors = tf.slice( fm_anchors = tf.slice(
all_anchors, [0, 0, 0, 0], tf.stack([ all_anchors, [0, 0, 0, 0], tf.stack([
tf.shape(image)[0] // config.ANCHOR_STRIDE,
tf.shape(image)[1] // config.ANCHOR_STRIDE, tf.shape(image)[1] // config.ANCHOR_STRIDE,
tf.shape(image)[2] // config.ANCHOR_STRIDE,
-1, -1]), name='fm_anchors') -1, -1]), name='fm_anchors')
return fm_anchors return fm_anchors
def _build_graph(self, inputs): def _build_graph(self, inputs):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
fm_anchors = self._get_anchors(image)
image = self._preprocess(image) image = self._preprocess(image)
fm_anchors = self._get_anchors(image)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors) anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3]) featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment