Commit 89ce9046 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] small updates on docs

parent e9fa7eb6
......@@ -63,7 +63,7 @@ CROWD_OVERLAP_THRES = 0.7 # boxes overlapping crowd will be ignored.
# fastrcnn training ---------------------
FASTRCNN_BATCH_PER_IM = 512
FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
FASTRCNN_BBOX_REG_WEIGHTS = [10., 10., 5., 5.] # Better but non-standard setting: [20, 20, 10, 10]
FASTRCNN_FG_THRESH = 0.5
FASTRCNN_FG_RATIO = 0.25 # fg ratio in a ROI batch
......
......@@ -154,8 +154,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
# Subsample bg labels. num_bg is not allowed to be too many
old_num_bg = np.sum(anchor_labels == 0)
if old_num_bg == 0:
# No valid bg/fg in this image, skip.
# This can happen if, e.g. the image has large crowd.
# No valid bg in this image, skip.
raise MalformedData("No valid background for RPN!")
target_num_bg = config.RPN_BATCH_PER_IM - len(fg_inds)
filter_box_label(anchor_labels, 0, target_num_bg) # ignore return values
......
......@@ -306,7 +306,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
ret_labels = tf.concat(
[tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0)
# stop the gradient -- they are meant to be ground-truth
# stop the gradient -- they are meant to be training targets
return tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'), \
tf.stop_gradient(ret_labels, name='sampled_labels'), \
tf.stop_gradient(fg_inds_wrt_gt)
......
......@@ -76,8 +76,9 @@ class DetectionModel(ModelDesc):
@under_name_scope()
def narrow_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
"""
Args:
Slice anchors/anchor_labels/anchor_boxes to the spatial size of this featuremap.
Args:
anchors (FS x FS x NA x 4):
anchor_labels (FS x FS x NA):
anchor_boxes (FS x FS x NA x 4):
......@@ -112,7 +113,7 @@ class DetectionModel(ModelDesc):
fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground targets
gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground targets
rcnn_label_logits (n): label logits for each sampled targets
fg_rcnn_box_logits (fg x 4): box logits for each sampled foreground targets
fg_rcnn_box_logits (fg x #class x 4): box logits for each sampled foreground targets
"""
with tf.name_scope('fg_sample_patch_viz'):
......@@ -124,7 +125,7 @@ class DetectionModel(ModelDesc):
tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
encoded_boxes = encode_bbox_target(
gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS, dtype=tf.float32)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, rcnn_label_logits,
encoded_boxes,
......@@ -138,7 +139,7 @@ class DetectionModel(ModelDesc):
image_shape2d: h, w
rcnn_boxes (nx4): the proposal boxes
rcnn_label_logits (n):
rcnn_box_logits (nx4):
rcnn_box_logits (nx #class x 4):
Returns:
boxes (mx4):
......@@ -148,7 +149,7 @@ class DetectionModel(ModelDesc):
anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target(
rcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS, dtype=tf.float32), anchors)
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category)
......@@ -332,7 +333,8 @@ class ResNetFPNModel(DetectionModel):
c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
p23456 = fpn_model('fpn', c2345)
# images are padded for p5, which are too large for p2-p4
# Images are padded for p5, which are too large for p2-p4.
# This seems to have no effect on mAP.
for i, stride in enumerate(config.ANCHOR_STRIDES_FPN[:3]):
pi = p23456[i]
target_shape = tf.to_int32(tf.ceil(tf.to_float(image_shape2d) * (1.0 / stride)))
......@@ -423,7 +425,7 @@ class ResNetFPNModel(DetectionModel):
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
'(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost')
total_cost = tf.add_n(rpn_loss_collection + [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment