Commit 17c25692 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] don't use default args, they are evaluated too early

parent 0050d6e7
...@@ -29,9 +29,7 @@ class MalformedData(BaseException): ...@@ -29,9 +29,7 @@ class MalformedData(BaseException):
@memoized @memoized
def get_all_anchors( def get_all_anchors(stride=None, sizes=None):
stride=cfg.RPN.ANCHOR_STRIDE,
sizes=cfg.RPN.ANCHOR_SIZES):
""" """
Get all anchors in the largest possible image, shifted, floatbox Get all anchors in the largest possible image, shifted, floatbox
Args: Args:
...@@ -43,6 +41,10 @@ def get_all_anchors( ...@@ -43,6 +41,10 @@ def get_all_anchors(
The layout in the NUM_ANCHOR dim is NUM_RATIO x NUM_SIZE. The layout in the NUM_ANCHOR dim is NUM_RATIO x NUM_SIZE.
""" """
if stride is None:
stride = cfg.RPN.ANCHOR_STRIDE
if sizes is None:
sizes = cfg.RPN.ANCHOR_SIZES
# Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors # Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
# are centered on stride / 2, have (approximate) sqrt areas of the specified # are centered on stride / 2, have (approximate) sqrt areas of the specified
# sizes, and aspect ratios as given. # sizes, and aspect ratios as given.
...@@ -69,20 +71,23 @@ def get_all_anchors( ...@@ -69,20 +71,23 @@ def get_all_anchors(
shifts.reshape((1, K, 4)).transpose((1, 0, 2))) shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4)) field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4))
# FSxFSxAx4 # FSxFSxAx4
assert np.all(field_of_anchors == field_of_anchors.astype('int32')) # Many rounding happens inside the anchor code anyway
# assert np.all(field_of_anchors == field_of_anchors.astype('int32'))
field_of_anchors = field_of_anchors.astype('float32') field_of_anchors = field_of_anchors.astype('float32')
field_of_anchors[:, :, :, [2, 3]] += 1 field_of_anchors[:, :, :, [2, 3]] += 1
return field_of_anchors return field_of_anchors
@memoized @memoized
def get_all_anchors_fpn( def get_all_anchors_fpn(strides=None, sizes=None):
strides=cfg.FPN.ANCHOR_STRIDES,
sizes=cfg.RPN.ANCHOR_SIZES):
""" """
Returns: Returns:
[anchors]: each anchors is a SxSx NUM_ANCHOR_RATIOS x4 array. [anchors]: each anchors is a SxSx NUM_ANCHOR_RATIOS x4 array.
""" """
if strides is None:
strides = cfg.FPN.ANCHOR_STRIDES
if sizes is None:
sizes = cfg.RPN.ANCHOR_SIZES
assert len(strides) == len(sizes) assert len(strides) == len(sizes)
foas = [] foas = []
for stride, size in zip(strides, sizes): for stride, size in zip(strides, sizes):
......
...@@ -122,6 +122,7 @@ def multilevel_roi_align(features, rcnn_boxes, resolution): ...@@ -122,6 +122,7 @@ def multilevel_roi_align(features, rcnn_boxes, resolution):
boxes_on_featuremap = boxes * (1.0 / cfg.FPN.ANCHOR_STRIDES[i]) boxes_on_featuremap = boxes * (1.0 / cfg.FPN.ANCHOR_STRIDES[i])
all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution)) all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution))
# this can fail if using TF<=1.8 with MKL build
all_rois = tf.concat(all_rois, axis=0) # NCHW all_rois = tf.concat(all_rois, axis=0) # NCHW
# Unshuffle to the original order, to match the original samples # Unshuffle to the original order, to match the original samples
level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment