Commit 92a9315e authored by Yuxin Wu

[FasterRCNN] include anchors at the boundary

parent 07fcdcc1
@@ -9,7 +9,7 @@ Tensorpack is a training interface based on TensorFlow.

 ## Features:

-It's Yet Another TF wrapper, but different in:
+It's Yet Another TF high-level API, with __speed__, __readability__ and __flexibility__ built together.

 1. Focus on __training speed__.
 + Speed comes for free with tensorpack -- it uses TensorFlow in the __efficient way__ with no extra overhead.
@@ -33,8 +33,9 @@ See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html)

 ## [Examples](examples):

 Instead of showing you 10 random networks trained on toy datasets,
-[tensorpack examples](examples) faithfully replicate papers and care about performance.
-And everything runs on multiple GPUs. Some highlights:
+[tensorpack examples](examples) faithfully replicate papers and care about reproducing numbers,
+demonstrating its flexibility for actual research.
+Some highlights:

 ### Vision:
 + [Train ResNet](examples/ResNet) and [other models](examples/ImageNetModels) on ImageNet.
@@ -34,7 +34,7 @@ def get_all_anchors(
     Get all anchors in the largest possible image, shifted, floatbox

     Returns:
-        anchors: SxSxNUM_ANCHORx4, where S == MAX_SIZE//STRIDE, floatbox
+        anchors: SxSxNUM_ANCHORx4, where S == ceil(MAX_SIZE/STRIDE), floatbox
         The layout in the NUM_ANCHOR dim is NUM_RATIO x NUM_SCALE.
     """
@@ -48,7 +48,7 @@ def get_all_anchors(
     # anchors are intbox here.
     # anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
-    field_size = config.MAX_SIZE // stride
+    field_size = int(np.ceil(config.MAX_SIZE / stride))
     shifts = np.arange(0, field_size) * stride
     shift_x, shift_y = np.meshgrid(shifts, shifts)
     shift_x = shift_x.flatten()
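For context, a minimal self-contained numpy sketch of the shifted-anchor grid this code builds (values assumed; not the library code itself):

```python
import numpy as np

stride, max_size = 16, 1333                      # assumed config values
field_size = int(np.ceil(max_size / stride))     # 84, now includes the boundary cell
shifts = np.arange(0, field_size) * stride       # top-left offset of each grid cell
shift_x, shift_y = np.meshgrid(shifts, shifts)
# one (x1, y1, x2, y2) shift per cell, later broadcast-added to the base anchors
shifts4 = np.stack([shift_x.ravel(), shift_y.ravel(),
                    shift_x.ravel(), shift_y.ravel()], axis=1)
print(shifts4.shape)                             # (7056, 4) == (84*84, 4)
```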
@@ -154,7 +154,7 @@ def get_rpn_anchor_input(im, boxes, is_crowd):
     ALL_ANCHORS = get_all_anchors()
     H, W = im.shape[:2]
-    featureH, featureW = H // config.ANCHOR_STRIDE, W // config.ANCHOR_STRIDE
+    anchorH, anchorW = ALL_ANCHORS.shape[:2]

     def filter_box_inside(im, boxes):
         h, w = im.shape[:2]
@@ -169,8 +169,7 @@ def get_rpn_anchor_input(im, boxes, is_crowd):
     non_crowd_boxes = boxes[is_crowd == 0]

     # fHxfWxAx4
-    featuremap_anchors = ALL_ANCHORS[:featureH, :featureW, :, :]
-    featuremap_anchors_flatten = featuremap_anchors.reshape((-1, 4))
+    featuremap_anchors_flatten = np.copy(ALL_ANCHORS).reshape((-1, 4))
     # only use anchors inside the image
     inside_ind = filter_box_inside(im, featuremap_anchors_flatten)
     inside_anchors = featuremap_anchors_flatten[inside_ind, :]
@@ -178,12 +177,12 @@ def get_rpn_anchor_input(im, boxes, is_crowd):
     anchor_labels, anchor_boxes = get_anchor_labels(inside_anchors, non_crowd_boxes, crowd_boxes)

     # Fill them back to original size: fHxfWx1, fHxfWx4
-    featuremap_labels = -np.ones((featureH * featureW * config.NUM_ANCHOR, ), dtype='int32')
+    featuremap_labels = -np.ones((anchorH * anchorW * config.NUM_ANCHOR, ), dtype='int32')
     featuremap_labels[inside_ind] = anchor_labels
-    featuremap_labels = featuremap_labels.reshape((featureH, featureW, config.NUM_ANCHOR))
-    featuremap_boxes = np.zeros((featureH * featureW * config.NUM_ANCHOR, 4), dtype='float32')
+    featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, config.NUM_ANCHOR))
+    featuremap_boxes = np.zeros((anchorH * anchorW * config.NUM_ANCHOR, 4), dtype='float32')
     featuremap_boxes[inside_ind, :] = anchor_boxes
-    featuremap_boxes = featuremap_boxes.reshape((featureH, featureW, config.NUM_ANCHOR, 4))
+    featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, config.NUM_ANCHOR, 4))
     return featuremap_labels, featuremap_boxes
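The "filter inside, label, scatter back" pattern above in isolation, with toy shapes (all values here are illustrative assumptions):

```python
import numpy as np

anchorH, anchorW, NUM_ANCHOR = 84, 84, 15       # assumed grid size and anchor count
n = anchorH * anchorW * NUM_ANCHOR
inside_ind = np.array([0, 5, 7])                # indices of anchors inside the image
anchor_labels = np.array([1, 0, 1], dtype='int32')

labels = -np.ones((n,), dtype='int32')          # -1 marks anchors ignored by the loss
labels[inside_ind] = anchor_labels              # scatter computed labels back
labels = labels.reshape((anchorH, anchorW, NUM_ANCHOR))
```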
@@ -298,7 +298,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
 @under_name_scope()
-def crop_and_resize(image, boxes, box_ind, crop_size):
+def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
     """
     Better-aligned version of tf.image.crop_and_resize, following our definition of floating point boxes.
@@ -312,9 +312,11 @@ def crop_and_resize(image, boxes, box_ind, crop_size):
     """
     assert isinstance(crop_size, int), crop_size
-    # TF's crop_and_resize fails on border
-    image = tf.pad(image, [[0, 0], [0, 0], [1, 1], [1, 1]])
-    boxes = boxes + 1
+    # TF's crop_and_resize produces zeros on border
+    if pad_border:
+        # this can be quite slow
+        image = tf.pad(image, [[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
+        boxes = boxes + 1

     @under_name_scope()
     def transform_fpcoor_for_tf(boxes, image_shape, crop_shape):
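A standalone check of why mode='SYMMETRIC' matters here (illustrative only; assumes TF2/eager execution for the prints):

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.]])
print(tf.pad(x, [[0, 0], [1, 1]]))                    # [[0. 1. 2. 3. 0.]] zeros leak into border samples
print(tf.pad(x, [[0, 0], [1, 1]], mode='SYMMETRIC'))  # [[1. 1. 2. 3. 3.]] edge values are replicated
```

The boxes are shifted by +1 afterwards so their coordinates still point at the same pixels in the one-pixel-padded image.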
@@ -75,7 +75,7 @@ class Model(ModelDesc):
         image = image_preprocess(image, bgr=True)
         return tf.transpose(image, [0, 3, 1, 2])

-    def _get_anchors(self, image):
+    def _get_anchors(self, shape2d):
         """
         Returns:
             FSxFSxNAx4 anchors,
@@ -85,9 +85,7 @@ class Model(ModelDesc):
         all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32)
         fm_anchors = tf.slice(
             all_anchors, [0, 0, 0, 0], tf.stack([
-                tf.shape(image)[0] // config.ANCHOR_STRIDE,
-                tf.shape(image)[1] // config.ANCHOR_STRIDE,
-                -1, -1]), name='fm_anchors')
+                shape2d[0], shape2d[1], -1, -1]), name='fm_anchors')
         return fm_anchors

     def build_graph(self, *inputs):
@@ -96,14 +94,24 @@ class Model(ModelDesc):
             image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
         else:
             image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
-        fm_anchors = self._get_anchors(image)
         image = self._preprocess(image)     # 1CHW
-        image_shape2d = tf.shape(image)[2:]
-        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

         featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
+        fm_shape = tf.shape(featuremap)[2:]     # h,w
+        fm_anchors = self._get_anchors(fm_shape)
+        anchor_labels = tf.slice(
+            anchor_labels, [0, 0, 0],
+            tf.stack([fm_shape[0], fm_shape[1], -1]),
+            name='sliced_anchor_labels')
+        anchor_boxes = tf.slice(
+            anchor_boxes, [0, 0, 0, 0],
+            tf.stack([fm_shape[0], fm_shape[1], -1, -1]),
+            name='sliced_anchor_boxes')
+        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
+        image_shape2d = tf.shape(image)[2:]     # h,w
+
         decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
         proposal_boxes, proposal_scores = generate_rpn_proposals(
             tf.reshape(decoded_boxes, [-1, 4]),
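A toy version of the new in-graph cropping step: since the dataflow now emits targets on the full ceil(MAX_SIZE/STRIDE) grid, they are sliced down to the actual featuremap extent at runtime. TF1-style placeholders and shapes below are illustrative assumptions:

```python
import tensorflow as tf

NUM_ANCHOR = 15                                                      # assumed anchor count
anchor_labels = tf.placeholder(tf.int32, [None, None, NUM_ANCHOR])   # full-grid labels
fm_shape = tf.constant([50, 38])                                     # featuremap h,w at runtime
sliced_labels = tf.slice(anchor_labels, [0, 0, 0],
                         tf.stack([fm_shape[0], fm_shape[1], -1]))   # crop to fHxfWxA
```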
@@ -170,7 +178,8 @@ class Model(ModelDesc):
                 target_masks_for_fg = crop_and_resize(
                     tf.expand_dims(gt_masks_for_fg, 1),
                     fg_sampled_boxes,
-                    tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
+                    tf.range(tf.size(fg_inds_wrt_gt)), 14,
+                    pad_border=False)  # nfg x 1x14x14
                 target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
                 mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
             else: