Commit e3f463ab authored by Yuxin Wu

FPN+mask

parent 4f13f971
......@@ -4,7 +4,7 @@
import numpy as np
# mode flags ---------------------
MODE_MASK = False
MODE_MASK = True
# dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR'
......@@ -25,7 +25,7 @@ WARMUP = 1000 # in steps
STEPS_PER_EPOCH = 500
LR_SCHEDULE = [150000, 230000, 280000]
LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
#LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
SHORT_EDGE_SIZE = 800
......@@ -73,6 +73,7 @@ RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE
MODE_FPN = True
FPN_NUM_CHANNEL = 256
MASKRCNN_HEAD_DIM = 256
FASTRCNN_FC_HEAD_DIM = 1024
FPN_RESOLUTION_REQUIREMENT = 32
TRAIN_FPN_NMS_TOPK = 2000
......
......@@ -344,7 +344,6 @@ def get_train_dataflow(add_mask=False):
return ret
ds = MultiProcessMapDataZMQ(ds, 10, preprocess)
#ds = PrefetchDataZMQ(ds, 3)
return ds
......
......@@ -141,12 +141,15 @@ def print_evaluation_scores(json_file):
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
ret['mAP(bbox)'] = cocoEval.stats[0]
fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
for k in range(6):
ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]
if config.MODE_MASK:
cocoEval = COCOeval(coco, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
ret['mAP(segm)'] = cocoEval.stats[0]
for k in range(6):
ret['mAP(segm)/' + fields[k]] = cocoEval.stats[k]
return ret
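A minimal standalone sketch (not part of this commit) of what the loops above rely on: after summarize(), pycocotools fills cocoEval.stats with a 12-element summary whose first six entries line up, in order, with the `fields` list.
fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
def collect_ap(stats, prefix):
    # stats[0:6] are AP over IoU 0.5:0.95, AP at IoU 0.5, AP at IoU 0.75,
    # and AP for small / medium / large objects; stats[6:12] are the AR counterparts.
    return {'mAP({})/{}'.format(prefix, f): stats[k] for k, f in enumerate(fields)}
# usage with dummy numbers standing in for a real cocoEval.stats array:
print(collect_ap([0.37, 0.59, 0.40, 0.21, 0.41, 0.48, 0, 0, 0, 0, 0, 0], 'bbox'))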
......@@ -3,6 +3,8 @@
import tensorflow as tf
import numpy as np
import itertools
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
......@@ -371,22 +373,22 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
@under_name_scope()
def roi_align(featuremap, boxes, output_shape):
def roi_align(featuremap, boxes, resolution):
"""
Args:
featuremap: 1xCxHxW
boxes: Nx4 floatbox
output_shape: int
resolution: output spatial resolution
Returns:
NxCxoHxoW
NxCx res x res
"""
boxes = tf.stop_gradient(boxes) # TODO
# sample 4 locations per roi bin
ret = crop_and_resize(
featuremap, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
output_shape * 2)
resolution * 2)
ret = tf.nn.avg_pool(ret, [1, 1, 2, 2], [1, 1, 2, 2], padding='SAME', data_format='NCHW')
return ret
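The `resolution * 2` crop followed by a 2x2 average pool above is how this roi_align samples and averages four locations per output bin. A small numpy sketch of just that reduction step; array names and sizes here are illustrative, not from the diff.
import numpy as np
def avg_pool_2x2_nchw(x):
    # average every non-overlapping 2x2 spatial block of an NCHW array,
    # mimicking tf.nn.avg_pool with ksize = strides = [1, 1, 2, 2]
    n, c, h, w = x.shape
    return x.reshape(n, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))
resolution = 7
crops = np.random.rand(3, 256, resolution * 2, resolution * 2)  # stand-in for crop_and_resize output
pooled = avg_pool_2x2_nchw(crops)
assert pooled.shape == (3, 256, resolution, resolution)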
......@@ -411,6 +413,25 @@ def fastrcnn_outputs(feature, num_classes):
return classification, box_regression
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, num_classes):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
dim = config.FASTRCNN_FC_HEAD_DIM
logger.info("fc-head-xavier-fanin")
#init = tf.random_normal_initializer(stddev=0.01)
init = tf.variance_scaling_initializer()
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, nl=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, nl=tf.nn.relu)
return fastrcnn_outputs('outputs', hidden, num_classes)
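Switching from the fixed stddev=0.01 normal init (commented out above) to tf.variance_scaling_initializer() draws the fc weights with variance roughly 1/fan_in ("xavier fan-in", as the log line says). A rough illustration under an assumed FPN RoI feature size; the numbers below are not from the diff.
import math
fan_in = 256 * 7 * 7                   # assumed: FPN_NUM_CHANNEL channels at the 7x7 RoI resolution
std_fan_in = math.sqrt(1.0 / fan_in)   # roughly 0.0089 for this fan-in
std_fixed = 0.01                       # the old fixed initializer
print(std_fan_in, std_fixed)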
@under_name_scope()
def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
"""
......@@ -508,20 +529,24 @@ def fastrcnn_predictions(boxes, probs):
@layer_register(log_shape=True)
def maskrcnn_head(feature, num_class):
def maskrcnn_upXconv_head(feature, num_class, num_convs):
"""
Args:
feature (NxCx7x7):
feature (NxCx s x s): size is 7 in C4 models and 14 in FPN models.
num_class (int): num_category + 1
num_convs (int): number of convolution layers
Returns:
mask_logits (N x num_category x 14 x 14):
mask_logits (N x num_category x 2s x 2s):
"""
l = feature
with argscope([Conv2D, Conv2DTranspose], data_format='channels_first',
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out', distribution='normal')):
# c2's MSRAFill is fan_out
l = Conv2DTranspose('deconv', feature, 256, 2, strides=2, activation=tf.nn.relu)
for k in range(num_convs):
l = Conv2D('fcn{}'.format(k), l, config.MASKRCNN_HEAD_DIM, 3, activation=tf.nn.relu)
l = Conv2DTranspose('deconv', l, config.MASKRCNN_HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
l = Conv2D('conv', l, num_class - 1, 1)
return l
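A quick shape check (not from the diff) of what this head produces: the 3x3 convs use 'same' padding so they keep the spatial size, and the stride-2 deconv doubles it, so a 7x7 C4 feature yields 14x14 mask logits (num_convs=0) and a 14x14 FPN feature yields 28x28 (num_convs=4).
def mask_head_output_size(input_size, num_convs):
    size = input_size
    for _ in range(num_convs):
        size = size            # 3x3 conv with 'same' padding keeps the size
    return size * 2            # stride-2 Conv2DTranspose doubles it
assert mask_head_output_size(7, 0) == 14    # C4 models
assert mask_head_output_size(14, 4) == 28   # FPN models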
......@@ -530,13 +555,13 @@ def maskrcnn_head(feature, num_class):
def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
"""
Args:
mask_logits: #fg x #category x14x14
mask_logits: #fg x #category xhxw
fg_labels: #fg, in 1~#class
fg_target_masks: #fgx14x14, int
fg_target_masks: #fgxhxw, int
"""
num_fg = tf.size(fg_labels)
indices = tf.stack([tf.range(num_fg), tf.to_int32(fg_labels) - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fgx14x14
mask_logits = tf.gather_nd(mask_logits, indices) # #fgxhxw
mask_probs = tf.sigmoid(mask_logits)
# add some training visualizations to tensorboard
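The indices/gather_nd pair above keeps, for each foreground roi, only the mask channel of its own (1-based) label. A numpy equivalent with illustrative shapes, not taken from the diff:
import numpy as np
num_fg, num_category, s = 5, 80, 14
mask_logits = np.random.randn(num_fg, num_category, s, s)
fg_labels = np.array([3, 1, 80, 7, 3])                           # values in 1..num_category
per_roi_logits = mask_logits[np.arange(num_fg), fg_labels - 1]   # num_fg x s x s, one channel per roi
assert per_roi_logits.shape == (num_fg, s, s)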
......@@ -642,22 +667,33 @@ def fpn_map_rois_to_levels(boxes):
return level_ids, level_boxes
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, dim, num_classes):
@under_name_scope()
def multilevel_roi_align(features, rcnn_boxes, resolution):
"""
Args:
feature (any shape):
dim (int): mlp dim
num_classes(int): num_category + 1
features ([tf.Tensor]): 4 FPN feature level 2-5
rcnn_boxes (tf.Tensor): nx4 boxes
resolution (int): output spatial resolution
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
NxC x res x res
"""
logger.info("fc-head-stddev=0.01")
init = tf.random_normal_initializer(stddev=0.01)
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, nl=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, nl=tf.nn.relu)
return fastrcnn_outputs('outputs', hidden, num_classes)
assert len(features) == 4, features
# Reassign rcnn_boxes to levels
level_ids, level_boxes = fpn_map_rois_to_levels(rcnn_boxes)
all_rois = []
# Crop patches from corresponding levels
for i, boxes, featuremap in zip(itertools.count(), level_boxes, features):
with tf.name_scope('roi_level{}'.format(i + 2)):
boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution))
all_rois = tf.concat(all_rois, axis=0) # NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N
level_id_invert_perm = tf.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm)
return all_rois
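The last three lines of multilevel_roi_align undo the grouping by level: tf.invert_permutation turns the concatenated level ids back into gather indices that restore the original box order. A tiny numpy sketch of that unshuffle (values are illustrative):
import numpy as np
level_id_perm = np.array([2, 0, 3, 1])         # original box index of each level-grouped roi
rois_by_level = np.array([20., 0., 30., 10.])  # stand-in for the concatenated roi features
invert_perm = np.argsort(level_id_perm)        # numpy analogue of tf.invert_permutation
rois_in_original_order = rois_by_level[invert_perm]
assert (rois_in_original_order == np.array([0., 10., 20., 30.])).all()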
if __name__ == '__main__':
......
......@@ -32,8 +32,9 @@ from model import (
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_head, maskrcnn_loss,
fpn_model, fpn_map_rois_to_levels, fastrcnn_2fc_head)
maskrcnn_upXconv_head, maskrcnn_loss,
fpn_model, fpn_map_rois_to_levels, fastrcnn_2fc_head,
multilevel_roi_align)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors, get_all_anchors_fpn)
......@@ -245,11 +246,12 @@ class ResNetC4Model(DetectionModel):
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
# In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS) # #fg x #cat x 14x14
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', fg_feature, config.NUM_CLASS, 0) # #fg x #cat x 14x14
gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt) # nfg x H x W
matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt) # nfg x H x W
target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks_for_fg, 1),
tf.expand_dims(matched_gt_masks, 1),
fg_sampled_boxes,
tf.range(tf.size(fg_inds_wrt_gt)), 14,
pad_border=False) # nfg x 1x14x14
......@@ -279,8 +281,8 @@ class ResNetC4Model(DetectionModel):
def f1():
roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
mask_logits = maskrcnn_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS) # #result x #cat x 14x14
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
return tf.sigmoid(final_mask_logits)
......@@ -370,25 +372,13 @@ class ResNetFPNModel(DetectionModel):
# The boxes to be used to crop RoIs.
rcnn_boxes = proposal_boxes
# Reassign rcnn_boxes to levels
level_ids, level_boxes = fpn_map_rois_to_levels(rcnn_boxes)
all_rois = []
# Crop patches from corresponding levels
for i, boxes, featuremap in zip(itertools.count(), level_boxes, p23456[:4]):
with tf.name_scope('roi_level{}'.format(i + 2)):
boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
all_rois.append(roi_align(featuremap, boxes_on_featuremap, 7))
all_rois = tf.concat(all_rois, axis=0) # NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N
level_id_invert_perm = tf.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm)
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
'fastrcnn', all_rois, config.FASTRCNN_FC_HEAD_DIM, config.NUM_CLASS)
'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)
if is_training:
# rpn loss ...
with tf.name_scope('rpn_losses'):
rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
......@@ -405,7 +395,24 @@ class ResNetFPNModel(DetectionModel):
image, rcnn_labels, fg_sampled_boxes,
matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
mrcnn_loss = 0.0
if config.MODE_MASK:
# maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], fg_sampled_boxes, 14)
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28
matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt) # fg x H x W
target_masks_for_fg = crop_and_resize(
tf.expand_dims(matched_gt_masks, 1),
fg_sampled_boxes,
tf.range(tf.size(fg_inds_wrt_gt)), 28,
pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
else:
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
......@@ -420,6 +427,14 @@ class ResNetFPNModel(DetectionModel):
else:
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK:
roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], final_boxes, 14)
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28
final_masks = tf.sigmoid(final_mask_logits, name='final_masks')
def visualize(model_path, nr_visualize=50, output_dir='output'):
......