Commit 7d0152ca authored by Yuxin Wu

[FasterRCNN] add fpn feature layers

parent aa96b7ca
...@@ -92,7 +92,7 @@ def resnet_group(l, name, block_func, features, count, stride): ...@@ -92,7 +92,7 @@ def resnet_group(l, name, block_func, features, count, stride):
return l return l
def pretrained_resnet_conv4(image, num_blocks, freeze_c2=True): def pretrained_resnet_c4_backbone(image, num_blocks, freeze_c2=True):
assert len(num_blocks) == 3 assert len(num_blocks) == 3
with resnet_argscope(): with resnet_argscope():
l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]]) l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]])
...@@ -114,3 +114,20 @@ def resnet_conv5(image, num_block): ...@@ -114,3 +114,20 @@ def resnet_conv5(image, num_block):
with resnet_argscope(): with resnet_argscope():
l = resnet_group(image, 'group3', resnet_bottleneck, 512, num_block, 2) l = resnet_group(image, 'group3', resnet_bottleneck, 512, num_block, 2)
return l return l
def pretrained_resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
    """
    Build the ResNet stem plus all four bottleneck groups and expose every
    group's output, so that an FPN can be attached on top.

    Args:
        image: input tensor. The paddings below are applied to its last two
            axes, i.e. the layout is treated as NCHW.
        num_blocks: sequence of exactly 4 ints, the number of bottleneck
            blocks in group0 .. group3.
        freeze_c2 (bool): when True, gradients are stopped at the output of
            group0, so the stem and group0 receive no gradient.

    Returns:
        tuple (c2, c3, c4, c5): the outputs of group0, group1, group2, group3.
    """
    assert len(num_blocks) == 4
    with resnet_argscope():
        # Stem: explicit asymmetric padding followed by 'VALID' conv / pooling.
        stem = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]])
        stem = Conv2D('conv0', stem, 64, 7, strides=2, activation=BNReLU, padding='VALID')
        stem = tf.pad(stem, [[0, 0], [0, 0], [0, 1], [0, 1]])
        stem = MaxPooling('pool0', stem, 3, strides=2, padding='VALID')

        c2 = resnet_group(stem, 'group0', resnet_bottleneck, 64, num_blocks[0], 1)
        if freeze_c2:
            c2 = tf.stop_gradient(c2)

        # group1..group3 each take the previous stage as input and use stride 2.
        stages = [c2]
        for group_idx, width in enumerate((128, 256, 512), start=1):
            stages.append(resnet_group(
                stages[-1], 'group{}'.format(group_idx), resnet_bottleneck,
                width, num_blocks[group_idx], 2))
    # 32x downsampling up to now
    return tuple(stages)
...@@ -71,3 +71,5 @@ RESULTS_PER_IM = 100 ...@@ -71,3 +71,5 @@ RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE # TODO Not Functioning. Don't USE
MODE_FPN = False MODE_FPN = False
FPN_NUM_CHANNEL = 256
FPN_SIZE_REQUIREMENT = 32
...@@ -312,6 +312,7 @@ def get_train_dataflow(add_mask=False): ...@@ -312,6 +312,7 @@ def get_train_dataflow(add_mask=False):
return None return None
ret = [im] + list(anchor_inputs) + [boxes, klass] ret = [im] + list(anchor_inputs) + [boxes, klass]
# TODO pad im when FPN
if add_mask: if add_mask:
# augmentation will modify the polys in-place # augmentation will modify the polys in-place
......
...@@ -7,7 +7,8 @@ from tensorpack.tfutils.summary import add_moving_summary ...@@ -7,7 +7,8 @@ from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import ( from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, layer_register, Conv2DTranspose) Conv2D, FullyConnected, GlobalAvgPooling, MaxPooling,
layer_register, Conv2DTranspose, FixedUnPooling)
from utils.box_ops import pairwise_iou from utils.box_ops import pairwise_iou
import config import config
...@@ -554,6 +555,32 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks): ...@@ -554,6 +555,32 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
return loss return loss
def fpn_model(features):
    """
    Build the FPN feature layers on top of four backbone feature maps.

    Args:
        features ([tf.Tensor]): exactly four feature maps, ordered from the
            finest level to the coarsest one (they are named c2 .. c5 by the
            layer names used below). All layers here are built with
            data_format='channels_first'. Each level is assumed to have half
            the spatial size of the previous one so that the 2x upsampled
            coarser map can be added to it -- TODO confirm against the caller.

    Returns:
        [tf.Tensor]: five tensors [p2, p3, p4, p5, p6]. p2 .. p5 are produced
        by 3x3 convs with config.FPN_NUM_CHANNEL output channels; p6 is p5
        subsampled with stride 2.
    """
    assert len(features) == 4, features
    num_channel = config.FPN_NUM_CHANNEL

    def upsample2x(name, x):
        # tensorpack registered layers take the scope name as first argument,
        # same as the Conv2D / MaxPooling calls below.
        # TODO may not be optimal in speed or math
        return FixedUnPooling(name, x, 2, data_format='channels_first')

    with argscope(Conv2D, data_format='channels_first',
                  nl=tf.identity, use_bias=True,
                  kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
        # 1x1 lateral connections, one per backbone level (c2 .. c5).
        lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
                    for i, c in enumerate(features)]

        # Top-down pathway: walk from the coarsest level to the finest one and
        # add the upsampled running sum of the coarser level to each lateral.
        lat_sum_5432 = []
        for idx, lat in enumerate(lat_2345[::-1]):
            if idx > 0:
                # idx == 1 upsamples the level-5 sum, idx == 2 the level-4 sum, ...
                lat = lat + upsample2x(
                    'upsample_lat{}'.format(6 - idx), lat_sum_5432[-1])
            lat_sum_5432.append(lat)

        # 3x3 conv on every merged map, back in fine-to-coarse order.
        p2345 = [Conv2D('fpn_3x3_p{}'.format(i + 2), c, num_channel, 3)
                 for i, c in enumerate(lat_sum_5432[::-1])]

    # The argscope above only covers Conv2D, so the layout must be passed
    # explicitly here; otherwise the NCHW tensor would be pooled as NHWC.
    p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2,
                    data_format='channels_first')
    return p2345 + [p6]
if __name__ == '__main__': if __name__ == '__main__':
""" """
Demonstrate what's wrong with tf.image.crop_and_resize: Demonstrate what's wrong with tf.image.crop_and_resize:
......
...@@ -17,6 +17,7 @@ assert six.PY3, "FasterRCNN requires Python 3!" ...@@ -17,6 +17,7 @@ assert six.PY3, "FasterRCNN requires Python 3!"
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
import tensorpack.utils.viz as tpviz import tensorpack.utils.viz as tpviz
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
...@@ -24,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -24,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection from coco import COCODetection
from basemodel import ( from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5) image_preprocess, pretrained_resnet_c4_backbone, resnet_conv5)
from model import ( from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize, clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses, rpn_head, rpn_losses,
...@@ -75,18 +76,22 @@ class Model(ModelDesc): ...@@ -75,18 +76,22 @@ class Model(ModelDesc):
image = image_preprocess(image, bgr=True) image = image_preprocess(image, bgr=True)
return tf.transpose(image, [0, 3, 1, 2]) return tf.transpose(image, [0, 3, 1, 2])
def _get_anchors(self, shape2d): @under_name_scope()
def slice_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
""" """
Returns: Args:
FSxFSxNAx4 anchors, Slice anchors/anchor_labels/anchor_boxes to the spatial size of this featuremap.
anchors (FS x FS x NA x 4):
anchor_labels (FS x FS x NA):
anchor_boxes (FS x FS x NA x 4):
""" """
# FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE) shape2d = tf.shape(featuremap)[2:] # h,w
with tf.name_scope('anchors'): slice3d = tf.concat([shape2d, [-1]], axis=0)
all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32) slice4d = tf.concat([shape2d, [-1, -1]], axis=0)
fm_anchors = tf.slice( anchors = tf.slice(anchors, [0, 0, 0, 0], slice4d)
all_anchors, [0, 0, 0, 0], tf.stack([ anchor_labels = tf.slice(anchor_labels, [0, 0, 0], slice3d)
shape2d[0], shape2d[1], -1, -1]), name='fm_anchors') anchor_boxes = tf.slice(anchor_boxes, [0, 0, 0, 0], slice4d)
return fm_anchors return anchors, anchor_labels, anchor_boxes
def build_graph(self, *inputs): def build_graph(self, *inputs):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
...@@ -96,19 +101,11 @@ class Model(ModelDesc): ...@@ -96,19 +101,11 @@ class Model(ModelDesc):
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self._preprocess(image) # 1CHW image = self._preprocess(image) # 1CHW
featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3]) featuremap = pretrained_resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
fm_shape = tf.shape(featuremap)[2:] # h,w
fm_anchors, anchor_labels, anchor_boxes = self.slice_to_featuremap(
fm_anchors = self._get_anchors(fm_shape) featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
anchor_labels = tf.slice(
anchor_labels, [0, 0, 0],
tf.stack([fm_shape[0], fm_shape[1], -1]),
name='sliced_anchor_labels')
anchor_boxes = tf.slice(
anchor_boxes, [0, 0, 0, 0],
tf.stack([fm_shape[0], fm_shape[1], -1, -1]),
name='sliced_anchor_boxes')
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors) anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
image_shape2d = tf.shape(image)[2:] # h,w image_shape2d = tf.shape(image)[2:] # h,w
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment