Commit b7f10ccf authored by Yuxin Wu

[FasterRCNN] add sample patches to tensorboard

parent a0601fb7
...@@ -47,7 +47,7 @@ To evaluate the performance (pretrained models can be downloaded in [model zoo]( ...@@ -47,7 +47,7 @@ To evaluate the performance (pretrained models can be downloaded in [model zoo](
Mean Average Precision @IoU=0.50:0.95: Mean Average Precision @IoU=0.50:0.95:
+ trainval35k/minival, FASTRCNN_BATCH=256: 34.2. Takes 49h on 8 TitanX. + trainval35k/minival, FASTRCNN_BATCH=256: 34.2. Takes 49h on 8 TitanX.
+ trainval35k/minival, FASTRCNN_BATCH=64: 32.7. Takes 31h on 8 TitanX. + trainval35k/minival, FASTRCNN_BATCH=64: 32.7. Takes 25h on 8 TitanX.
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer. The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer.
......
...@@ -25,7 +25,7 @@ from coco import COCODetection ...@@ -25,7 +25,7 @@ from coco import COCODetection
from basemodel import ( from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5) image_preprocess, pretrained_resnet_conv4, resnet_conv5)
from model import ( from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses, rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align, generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions) fastrcnn_head, fastrcnn_losses, fastrcnn_predictions)
...@@ -81,7 +81,7 @@ class Model(ModelDesc): ...@@ -81,7 +81,7 @@ class Model(ModelDesc):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
fm_anchors = self._get_anchors(image) fm_anchors = self._get_anchors(image)
image = self._preprocess(image) image = self._preprocess(image) # 1CHW
image_shape2d = tf.shape(image)[2:] image_shape2d = tf.shape(image)[2:]
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors) anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
...@@ -116,6 +116,13 @@ class Model(ModelDesc): ...@@ -116,6 +116,13 @@ class Model(ModelDesc):
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample) fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)
with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize(
image, fg_sampled_boxes,
tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), [300, 300])
fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt) matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
encoded_boxes = encode_bbox_target( encoded_boxes = encode_bbox_target(
matched_gt_boxes, matched_gt_boxes,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment