Commit a19eb489 authored by Yuxin Wu

Merge branch 'fpn'

parents 804c5ee6 4fe9e5b1
# Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) implementation of
end-to-end Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO.
This example provides a minimal (only 1.6k lines) but faithful implementation of the
following papers in combination:
+ [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)
+ [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+ [Mask R-CNN](https://arxiv.org/abs/1703.06870)
## Dependencies
+ Python 3; TensorFlow >= 1.4.0
+ Python 3; TensorFlow >= 1.4.0 (>= 1.6.0 recommended due to a TF bug)
+ [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
+ Pre-trained [ResNet model](http://models.tensorpack.com/ResNet/) from tensorpack model zoo.
+ COCO data. It assumes the following directory structure:
......@@ -53,18 +57,21 @@ MaskRCNN results contain both bbox and segm mAP.
|Backbone|`FASTRCNN_BATCH`|resolution |schedule|mAP (bbox/segm)|Time |
| - | - | - | - | - | - |
|R-50 |64 |(600, 1024)|280k |33.1 |18h on 8 V100s|
|R-50 |512 |(800, 1333)|280k |35.6 |55h on 8 P100s|
|R-50 |512 |(800, 1333)|360k |36.6 |49h on 8 V100s|
|R-50 |256 |(800, 1333)|280k |36.8/32.1 |39h on 8 P100s|
|R-50 |512 |(800, 1333)|360k |37.8/33.2 |51h on 8 V100s|
|R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
|R50-C4 |64 |(600, 1024)|280k |33.1 |18h on 8 V100s|
|R50-C4 |512 |(800, 1333)|280k |35.6 |55h on 8 P100s|
|R50-C4 |512 |(800, 1333)|360k |36.6 |49h on 8 V100s|
|R50-FPN |512 |(800, 1333)|360k |37.5 |28h on 8 V100s|
|R50-C4 |256 |(800, 1333)|280k |36.8/32.1 |39h on 8 P100s|
|R50-C4 |512 |(800, 1333)|360k |37.8/33.2 |51h on 8 V100s|
|R50-FPN |512 |(800, 1333)|360k |38.1/34.9 |38h on 8 V100s|
|R101-C4 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R101-C4 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
The two R50-C4 360k models have the same configuration __and mAP__
as the `R50-C4-2x` entries in
[Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines).
So far this seems to be the only open source re-implementation that can reproduce mAP in Detectron.
So far this is the only TensorFlow implementation that can reproduce mAP in Detectron.
The other models listed here do not correspond to any configurations in Detectron.
## Notes
......
......@@ -92,7 +92,7 @@ def resnet_group(l, name, block_func, features, count, stride):
return l
def pretrained_resnet_c4_backbone(image, num_blocks, freeze_c2=True):
def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
assert len(num_blocks) == 3
with resnet_argscope():
l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]])
......@@ -116,10 +116,19 @@ def resnet_conv5(image, num_block):
return l
def pretrained_resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
shape2d = tf.shape(image)[2:]
mult = config.FPN_RESOLUTION_REQUIREMENT * 1.
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
pad_shape2d = new_shape2d - shape2d
assert len(num_blocks) == 4
# TODO pad 1 at each stage
with resnet_argscope():
l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]])
chan = image.shape[1]
l = tf.pad(image,
tf.stack([[0, 0], [0, 0],
[2, 3 + pad_shape2d[0]], [2, 3 + pad_shape2d[1]]]))
l.set_shape([None, chan, None, None])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
......
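The FPN backbone above pads each image so its height and width become multiples of `FPN_RESOLUTION_REQUIREMENT` (32), which keeps the stride-2 reductions in C2-C5 exactly aligned across pyramid levels. A minimal NumPy sketch of just the rounding arithmetic (not part of the diff; the extra `[2, 3]` padding for the stride-2 7x7 conv is separate):

```python
import numpy as np

def fpn_pad_hw(h, w, mult=32):   # mult = config.FPN_RESOLUTION_REQUIREMENT
    shape2d = np.array([h, w])
    new_shape2d = np.int32(np.ceil(shape2d / float(mult)) * mult)
    return new_shape2d - shape2d   # extra rows/cols appended at the bottom/right

print(fpn_pad_hw(800, 1333))       # [ 0 11]: 1333 is padded up to 1344
```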
......@@ -5,6 +5,7 @@ import numpy as np
# mode flags ---------------------
MODE_MASK = True
MODE_FPN = False
# dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR'
......@@ -34,6 +35,7 @@ MAX_SIZE = 1333
# anchors -------------------------
ANCHOR_STRIDE = 16
ANCHOR_STRIDES_FPN = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES
FPN_RESOLUTION_REQUIREMENT = 32 # image size fed into the backbone has to be a multiple of this number
ANCHOR_SIZES = (32, 64, 128, 256, 512) # sqrtarea of the anchor box
ANCHOR_RATIOS = (0.5, 1., 2.)
NUM_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
......@@ -48,6 +50,7 @@ RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH = 0.7
TRAIN_PRE_NMS_TOPK = 12000
TRAIN_POST_NMS_TOPK = 2000
TRAIN_FPN_NMS_TOPK = 2000
CROWD_OVERLAP_THRES = 0.7 # boxes overlapping crowd will be ignored.
# fastrcnn training ---------------------
......@@ -56,15 +59,16 @@ FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
FASTRCNN_FG_THRESH = 0.5
FASTRCNN_FG_RATIO = 0.25 # fg ratio in a ROI batch
# modeling -------------------------
FPN_NUM_CHANNEL = 256
FASTRCNN_FC_HEAD_DIM = 1024
MASKRCNN_HEAD_DIM = 256
# testing -----------------------
TEST_PRE_NMS_TOPK = 6000
TEST_POST_NMS_TOPK = 1000 # if you encounter OOM in inference, set this to a smaller number
TEST_FPN_NMS_TOPK = 1000
FASTRCNN_NMS_THRESH = 0.5
RESULT_SCORE_THRESH = 0.05
RESULT_SCORE_THRESH_VIS = 0.3 # only visualize confident results
RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE
MODE_FPN = False
FPN_NUM_CHANNEL = 256
FPN_SIZE_REQUIREMENT = 32
......@@ -8,7 +8,7 @@ import itertools
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
imgaug, TestDataSpeed, PrefetchDataZMQ, MapData,
imgaug, TestDataSpeed, PrefetchDataZMQ, MultiProcessMapDataZMQ,
MapDataComponent, DataFromList)
# import tensorpack.utils.viz as tpviz
......@@ -51,7 +51,12 @@ def get_all_anchors(
# anchors are intbox here.
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
field_size = int(np.ceil(config.MAX_SIZE / stride))
max_size = config.MAX_SIZE
if config.MODE_FPN:
# TODO setting this in config is perhaps better
size_mult = config.FPN_RESOLUTION_REQUIREMENT * 1.
max_size = np.ceil(max_size / size_mult) * size_mult
field_size = int(np.ceil(max_size / stride))
shifts = np.arange(0, field_size) * stride
shift_x, shift_y = np.meshgrid(shifts, shifts)
shift_x = shift_x.flatten()
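To make the FPN branch above concrete: rounding `MAX_SIZE` up to a multiple of `FPN_RESOLUTION_REQUIREMENT` guarantees that each level's anchor field covers the padded image exactly. A quick NumPy check of the numbers (illustrative, not part of the diff):

```python
import numpy as np

stride, max_size = 16, 1333
size_mult = 32.                                        # FPN_RESOLUTION_REQUIREMENT
max_size = np.ceil(max_size / size_mult) * size_mult   # 1344.0
field_size = int(np.ceil(max_size / stride))           # 84
shifts = np.arange(0, field_size) * stride
print(field_size, shifts[0], shifts[-1])               # 84 0 1328
```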
......@@ -136,17 +141,19 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES]
anchor_labels[overlap_with_crowd] = -1
# Filter fg labels: ignore some fg if fg is too many
# Subsample fg labels: ignore some fg if fg is too many
target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO)
fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
if len(fg_inds) == 0:
raise MalformedData("No valid foreground for RPN!")
# Note that fg could be fewer than the target ratio
# filter bg labels. num_bg is not allowed to be too many
# Subsample bg labels. num_bg is not allowed to be too many
old_num_bg = np.sum(anchor_labels == 0)
if old_num_bg == 0 or len(fg_inds) == 0:
if old_num_bg == 0:
# No valid bg/fg in this image, skip.
# This can happen if, e.g., the image has a large crowd.
raise MalformedData("No valid foreground/background for RPN!")
raise MalformedData("No valid background for RPN!")
target_num_bg = config.RPN_BATCH_PER_IM - len(fg_inds)
filter_box_label(anchor_labels, 0, target_num_bg) # ignore return values
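The subsampling above can be summarized as: keep at most `RPN_FG_RATIO * RPN_BATCH_PER_IM` foreground anchors, then fill the rest of the batch with background anchors, marking the excess as ignored. A hedged NumPy sketch of the same logic (`filter_box_label` is the example's own helper; this stand-in mimics its disable-by-setting-to-minus-one behavior, with the repo's default config values):

```python
import numpy as np

def subsample_labels(anchor_labels, batch_per_im=256, fg_ratio=0.5):
    # RPN label convention: 1 = fg, 0 = bg, -1 = ignored.
    labels = anchor_labels.copy()
    target_num_fg = int(batch_per_im * fg_ratio)
    fg_inds = np.where(labels == 1)[0]
    if len(fg_inds) > target_num_fg:       # too many fg: randomly ignore extras
        drop = np.random.choice(fg_inds, len(fg_inds) - target_num_fg, replace=False)
        labels[drop] = -1
        fg_inds = np.where(labels == 1)[0]
    target_num_bg = batch_per_im - len(fg_inds)   # fill the batch with bg
    bg_inds = np.where(labels == 0)[0]
    if len(bg_inds) > target_num_bg:
        drop = np.random.choice(bg_inds, len(bg_inds) - target_num_bg, replace=False)
        labels[drop] = -1
    return labels
```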
......@@ -336,8 +343,7 @@ def get_train_dataflow(add_mask=False):
# tpviz.interactive_imshow(viz)
return ret
ds = MapData(ds, preprocess)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessMapDataZMQ(ds, 10, preprocess)
return ds
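The dataflow change above replaces a serial `MapData` plus a single `PrefetchDataZMQ` process with `MultiProcessMapDataZMQ`, which runs the CPU-heavy `preprocess` in 10 worker processes over ZMQ. A hedged minimal sketch of the pattern (output order is not preserved, which is fine for training):

```python
from tensorpack.dataflow import DataFromList, MultiProcessMapDataZMQ

def preprocess(dp):              # stand-in for the expensive map function above
    return [dp[0] * 2]

ds = DataFromList([[i] for i in range(1000)], shuffle=True)
ds = MultiProcessMapDataZMQ(ds, 10, preprocess)   # 10 parallel worker processes
ds.reset_state()                 # required before iterating a tensorpack dataflow
```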
......@@ -359,7 +365,6 @@ if __name__ == '__main__':
import os
from tensorpack.dataflow import PrintData
config.BASEDIR = os.path.expanduser('~/data/coco')
config.TRAIN_DATASET = ['train2014']
ds = get_train_dataflow(add_mask=config.MODE_MASK)
ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start()
......
......@@ -141,12 +141,15 @@ def print_evaluation_scores(json_file):
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
ret['mAP(bbox)'] = cocoEval.stats[0]
fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
for k in range(6):
ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]
if config.MODE_MASK:
cocoEval = COCOeval(coco, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
ret['mAP(segm)'] = cocoEval.stats[0]
for k in range(6):
ret['mAP(segm)/' + fields[k]] = cocoEval.stats[k]
return ret
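With this change `print_evaluation_scores` reports all six COCO summary numbers per task instead of only `stats[0]`. For reference, the exact keys the loop above produces (a quick check, not part of the diff):

```python
fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
print(['mAP(bbox)/' + f for f in fields][:3])
# ['mAP(bbox)/IoU=0.5:0.95', 'mAP(bbox)/IoU=0.5', 'mAP(bbox)/IoU=0.75']
```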
......@@ -2,15 +2,18 @@
# File: model.py
import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
import numpy as np
import itertools
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, MaxPooling,
Conv2D, FullyConnected, MaxPooling,
layer_register, Conv2DTranspose, FixedUnPooling)
from utils.box_ops import pairwise_iou
from utils.box_ops import area as tf_area
import config
......@@ -28,6 +31,7 @@ def clip_boxes(boxes, window, name=None):
@layer_register(log_shape=True)
@auto_reuse_variable_scope
def rpn_head(featuremap, channel, num_anchors):
"""
Returns:
......@@ -67,7 +71,8 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_mask = tf.stop_gradient(tf.not_equal(anchor_labels, -1))
pos_mask = tf.stop_gradient(tf.equal(anchor_labels, 1))
nr_valid = tf.stop_gradient(tf.count_nonzero(valid_mask, dtype=tf.int32), name='num_valid_anchor')
nr_pos = tf.count_nonzero(pos_mask, dtype=tf.int32, name='num_pos_anchor')
nr_pos = tf.identity(tf.count_nonzero(pos_mask, dtype=tf.int32), name='num_pos_anchor')
# nr_pos is guaranteed >0 in C4. But in FPN, even nr_valid could be 0.
valid_anchor_labels = tf.boolean_mask(anchor_labels, valid_mask)
valid_label_logits = tf.boolean_mask(label_logits, valid_mask)
......@@ -84,17 +89,22 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)),
dtype=tf.int32)
summaries.append(tf.truediv(
pos_prediction_corr,
nr_pos, name='recall_th{}'.format(th)))
placeholder = 0.5 # A small value will make summaries appear lower.
recall = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos))
recall = tf.where(tf.equal(nr_pos, 0), placeholder, recall, name='recall_th{}'.format(th))
precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0.0, precision, name='precision_th{}'.format(th))
summaries.append(precision)
precision = tf.where(tf.equal(nr_pos_prediction, 0),
placeholder, precision, name='precision_th{}'.format(th))
summaries.extend([precision, recall])
add_moving_summary(*summaries)
# Per-level loss summaries in FPN may appear lower due to the use of a small placeholder.
# But the total loss is still the same.
placeholder = 0.
label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
label_loss = tf.reduce_mean(label_loss, name='label_loss')
label_loss = tf.reduce_sum(label_loss) * (1. / config.RPN_BATCH_PER_IM)
label_loss = tf.where(tf.equal(nr_valid, 0), placeholder, label_loss, name='label_loss')
pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask)
pos_box_logits = tf.boolean_mask(box_logits, pos_mask)
......@@ -102,9 +112,8 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
box_loss = tf.losses.huber_loss(
pos_anchor_boxes, pos_box_logits, delta=delta,
reduction=tf.losses.Reduction.SUM) / delta
box_loss = tf.div(
box_loss,
tf.cast(nr_valid, tf.float32), name='box_loss')
box_loss = box_loss * (1. / config.RPN_BATCH_PER_IM)
box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss')
add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
return label_loss, box_loss
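The `tf.where` pattern introduced above is the standard TF-1.x guard for a ratio or loss whose denominator can be zero (which now happens per FPN level). A standalone sketch of the guard; the function name is mine, not from the diff:

```python
import tensorflow as tf

def safe_mean(total, count, placeholder=0.0, name=None):
    # When `count` is 0 the division would produce NaN/Inf; tf.where
    # substitutes a fixed placeholder so summaries and losses stay finite.
    ratio = tf.truediv(tf.to_float(total), tf.to_float(count))
    return tf.where(tf.equal(count, 0), placeholder, ratio, name=name)
```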
......@@ -167,26 +176,29 @@ def encode_bbox_target(boxes, anchors):
@under_name_scope()
def generate_rpn_proposals(boxes, scores, img_shape):
def generate_rpn_proposals(boxes, scores, img_shape,
pre_nms_topk, post_nms_topk=None):
"""
Sample RPN proposals by the following steps:
1. Pick top k1 by scores
2. NMS them
3. Pick top k2 by scores. Default k2 == k1, i.e. does not filter the NMS output.
Args:
boxes: nx4 float dtype, decoded to floatbox already
boxes: nx4 float dtype, the proposal boxes. Decoded to floatbox already
scores: n float, the logits
img_shape: [h, w]
pre_nms_topk, post_nms_topk (int): See above.
Returns:
boxes: kx4 float
scores: k logits
"""
assert boxes.shape.ndims == 2, boxes.shape
if get_current_tower_context().is_training:
PRE_NMS_TOPK = config.TRAIN_PRE_NMS_TOPK
POST_NMS_TOPK = config.TRAIN_POST_NMS_TOPK
else:
PRE_NMS_TOPK = config.TEST_PRE_NMS_TOPK
POST_NMS_TOPK = config.TEST_POST_NMS_TOPK
if post_nms_topk is None:
post_nms_topk = pre_nms_topk
topk = tf.minimum(PRE_NMS_TOPK, tf.size(scores))
topk = tf.minimum(pre_nms_topk, tf.size(scores))
topk_scores, topk_indices = tf.nn.top_k(scores, k=topk, sorted=False)
topk_boxes = tf.gather(boxes, topk_indices)
topk_boxes = clip_boxes(topk_boxes, img_shape)
......@@ -199,22 +211,21 @@ def generate_rpn_proposals(boxes, scores, img_shape):
topk_valid_boxes_x1y1x2y2 = tf.boolean_mask(topk_boxes_x1y1x2y2, valid)
topk_valid_scores = tf.boolean_mask(topk_scores, valid)
# TODO not needed
topk_valid_boxes_y1x1y2x2 = tf.reshape(
tf.reverse(topk_valid_boxes_x1y1x2y2, axis=[2]),
(-1, 4), name='nms_input_boxes')
nms_indices = tf.image.non_max_suppression(
topk_valid_boxes_y1x1y2x2,
topk_valid_scores,
max_output_size=POST_NMS_TOPK,
max_output_size=post_nms_topk,
iou_threshold=config.RPN_PROPOSAL_NMS_THRESH)
topk_valid_boxes = tf.reshape(topk_valid_boxes_x1y1x2y2, (-1, 4))
final_boxes = tf.gather(
topk_valid_boxes,
nms_indices, name='boxes')
final_scores = tf.gather(topk_valid_scores, nms_indices, name='scores')
final_boxes = tf.gather(topk_valid_boxes, nms_indices)
final_scores = tf.gather(topk_valid_scores, nms_indices)
tf.sigmoid(final_scores, name='probs') # for visualization
return final_boxes, final_scores
return tf.stop_gradient(final_boxes, name='boxes'), tf.stop_gradient(final_scores, name='scores')
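The refactor makes the NMS budgets explicit arguments instead of reading config inside the function, because C4 and FPN need different budgets; the outputs are also wrapped in `tf.stop_gradient` since proposals are treated as constants by the second stage. The two call patterns, as they appear later in this diff:

```python
# C4 path: separate pre-/post-NMS budgets.
proposal_boxes, proposal_scores = generate_rpn_proposals(
    tf.reshape(pred_boxes_decoded, [-1, 4]),
    tf.reshape(rpn_label_logits, [-1]),
    image_shape2d,
    config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
    config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)

# FPN path: one budget per level; post_nms_topk defaults to pre_nms_topk,
# so the NMS output is not re-truncated per level.
proposal_boxes, proposal_scores = generate_rpn_proposals(
    tf.reshape(pred_boxes_decoded, [-1, 4]),
    tf.reshape(rpn_label_logits, [-1]),
    image_shape2d,
    config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
```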
@under_name_scope()
......@@ -243,6 +254,7 @@ def proposal_metrics(iou):
def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
"""
Sample some ROIs from all proposals for training.
#fg is guaranteed to be > 0, because ground truth boxes are added as RoIs.
Args:
boxes: nx4 region proposals, floatbox
......@@ -288,13 +300,15 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
fg_inds_wrt_gt = tf.gather(best_iou_ind, fg_inds) # num_fg
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # indices w.r.t all n+m proposal boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_proposal_boxes')
ret_boxes = tf.gather(boxes, all_indices)
ret_labels = tf.concat(
[tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0)
# stop the gradient -- they are meant to be ground-truth
return tf.stop_gradient(ret_boxes), tf.stop_gradient(ret_labels), fg_inds_wrt_gt
return tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'), \
tf.stop_gradient(ret_labels, name='sampled_labels'), \
tf.stop_gradient(fg_inds_wrt_gt)
@under_name_scope()
......@@ -360,37 +374,36 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
@under_name_scope()
def roi_align(featuremap, boxes, output_shape):
def roi_align(featuremap, boxes, resolution):
"""
Args:
featuremap: 1xCxHxW
boxes: Nx4 floatbox
output_shape: int
resolution: output spatial resolution
Returns:
NxCxoHxoW
NxCx res x res
"""
boxes = tf.stop_gradient(boxes) # TODO
# sample 4 locations per roi bin
ret = crop_and_resize(
featuremap, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
output_shape * 2)
resolution * 2)
ret = tf.nn.avg_pool(ret, [1, 1, 2, 2], [1, 1, 2, 2], padding='SAME', data_format='NCHW')
return ret
@layer_register(log_shape=True)
def fastrcnn_head(feature, num_classes):
def fastrcnn_outputs(feature, num_classes):
"""
Args:
feature (NxCx7x7):
feature (any shape):
num_classes(int): num_category + 1
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
feature = GlobalAvgPooling('gap', feature, data_format='channels_first')
classification = FullyConnected(
'class', feature, num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=0.01))
......@@ -401,6 +414,23 @@ def fastrcnn_head(feature, num_classes):
return classification, box_regression
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, num_classes):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
dim = config.FASTRCNN_FC_HEAD_DIM
init = tf.variance_scaling_initializer()
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, nl=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, nl=tf.nn.relu)
return fastrcnn_outputs('outputs', hidden, num_classes)
@under_name_scope()
def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
"""
......@@ -498,20 +528,24 @@ def fastrcnn_predictions(boxes, probs):
@layer_register(log_shape=True)
def maskrcnn_head(feature, num_class):
def maskrcnn_upXconv_head(feature, num_class, num_convs):
"""
Args:
feature (NxCx7x7):
feature (NxCx s x s): size is 7 in C4 models and 14 in FPN models.
num_classes(int): num_category + 1
num_convs (int): number of convolution layers
Returns:
mask_logits (N x num_category x 14 x 14):
mask_logits (N x num_category x 2s x 2s):
"""
l = feature
with argscope([Conv2D, Conv2DTranspose], data_format='channels_first',
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out', distribution='normal')):
# c2's MSRAFill is fan_out
l = Conv2DTranspose('deconv', feature, 256, 2, strides=2, activation=tf.nn.relu)
for k in range(num_convs):
l = Conv2D('fcn{}'.format(k), l, config.MASKRCNN_HEAD_DIM, 3, activation=tf.nn.relu)
l = Conv2DTranspose('deconv', l, config.MASKRCNN_HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
l = Conv2D('conv', l, num_class - 1, 1)
return l
......@@ -520,13 +554,13 @@ def maskrcnn_head(feature, num_class):
def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
"""
Args:
mask_logits: #fg x #category x14x14
mask_logits: #fg x #category xhxw
fg_labels: #fg, in 1~#class
fg_target_masks: #fgx14x14, int
fg_target_masks: #fgxhxw, int
"""
num_fg = tf.size(fg_labels)
indices = tf.stack([tf.range(num_fg), tf.to_int32(fg_labels) - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fgx14x14
mask_logits = tf.gather_nd(mask_logits, indices) # #fgxhxw
mask_probs = tf.sigmoid(mask_logits)
# add some training visualizations to tensorboard
......@@ -555,13 +589,31 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
return loss
@layer_register(log_shape=True)
def fpn_model(features):
"""
Args:
features ([tf.Tensor]): ResNet features c2-c5
Returns:
[tf.Tensor]: FPN features p2-p6
"""
assert len(features) == 4, features
num_channel = config.FPN_NUM_CHANNEL
def upsample2x(x):
# TODO may not be optimal in speed or math
return FixedUnPooling(x, 2, data_format='channels_first')
def upsample2x(name, x):
return FixedUnPooling(
name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
data_format='channels_first')
# tf.image.resize is, again, not aligned.
# with tf.name_scope(name):
# logger.info("Nearest neighbor")
# shape2d = tf.shape(x)[2:]
# x = tf.transpose(x, [0, 2, 3, 1])
# x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
# x = tf.transpose(x, [0, 3, 1, 2])
# return x
with argscope(Conv2D, data_format='channels_first',
nl=tf.identity, use_bias=True,
......@@ -573,19 +625,81 @@ def fpn_model(features):
if idx == 0:
lat_sum_5432.append(lat)
else:
lat = lat + upsample2x(lat_sum_5432[-1])
lat = lat + upsample2x('upsample_lat{}'.format(6 - idx), lat_sum_5432[-1])
lat_sum_5432.append(lat)
p2345 = [Conv2D('fpn_3x3_p{}'.format(i + 2), c, num_channel, 3)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])]
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2)
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first')
return p2345 + [p6]
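A note on the upsampling choice above: `FixedUnPooling` with a 2x2 all-ones `unpool_mat` replicates every value into a 2x2 block, i.e. exact nearest-neighbor upsampling with no half-pixel shift, which is why it is preferred over the commented-out `tf.image.resize_nearest_neighbor` path. A NumPy sketch of the effect (illustrative only):

```python
import numpy as np

x = np.arange(4).reshape(2, 2)
print(np.kron(x, np.ones((2, 2), dtype=x.dtype)))  # each entry becomes a 2x2 block
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]
```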
@under_name_scope()
def fpn_map_rois_to_levels(boxes):
"""
Assign boxes to level 2~5.
Args:
boxes (nx4)
Returns:
[tf.Tensor]: 4 tensors for level 2-5. Each tensor is a vector of indices of boxes in its level.
[tf.Tensor]: 4 tensors, the gathered boxes in each level.
Be careful that the returned tensor could be empty.
"""
sqrtarea = tf.sqrt(tf_area(boxes))
level = tf.to_int32(tf.floor(
4 + tf.log(sqrtarea * (1. / 224) + 1e-6) * (1.0 / np.log(2))))
# RoI levels range from 2~5 (not 6)
level_ids = [
tf.where(level <= 2),
tf.where(tf.equal(level, 3)), # == is not supported
tf.where(tf.equal(level, 4)),
tf.where(level >= 5)]
level_ids = [tf.reshape(x, [-1], name='roi_level{}_id'.format(i + 2))
for i, x in enumerate(level_ids)]
num_in_levels = [tf.size(x, name='num_roi_level{}'.format(i + 2))
for i, x in enumerate(level_ids)]
add_moving_summary(*num_in_levels)
level_boxes = [tf.gather(boxes, ids) for ids in level_ids]
return level_ids, level_boxes
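The level assignment above follows Eq. (1) of the FPN paper: k = floor(k0 + log2(sqrt(wh)/224)) with k0 = 4, clamped to [2, 5]. A quick NumPy check of where boxes of various sizes land (the `level_ids` buckets above implement the clamp; this sketch folds it into `np.clip`):

```python
import numpy as np

def fpn_level(sqrtarea):
    lvl = np.floor(4 + np.log2(sqrtarea / 224.0))
    return int(np.clip(lvl, 2, 5))

for s in [56, 112, 224, 448, 896]:
    print(s, '->', fpn_level(s))   # 2, 3, 4, 5, 5
```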
@under_name_scope()
def multilevel_roi_align(features, rcnn_boxes, resolution):
"""
Args:
features ([tf.Tensor]): 4 FPN feature level 2-5
rcnn_boxes (tf.Tensor): nx4 boxes
resolution (int): output spatial resolution
Returns:
NxC x res x res
"""
assert len(features) == 4, features
# Reassign rcnn_boxes to levels
level_ids, level_boxes = fpn_map_rois_to_levels(rcnn_boxes)
all_rois = []
# Crop patches from corresponding levels
for i, boxes, featuremap in zip(itertools.count(), level_boxes, features):
with tf.name_scope('roi_level{}'.format(i + 2)):
boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution))
all_rois = tf.concat(all_rois, axis=0) # NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm = tf.concat(level_ids, axis=0) # a 0-based permutation of range(N)
level_id_invert_perm = tf.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm)
return all_rois
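Because boxes are regrouped by level before ROIAlign, the concatenated crops come out level-ordered; `tf.invert_permutation` restores the caller's original box order so the RoIs still match the sampled targets. A tiny NumPy analogue of the unshuffle (illustrative only; `np.argsort` of a permutation is its inverse):

```python
import numpy as np

# Suppose 5 boxes were routed to levels and the concatenated per-level id
# vectors came out as [2, 0, 3, 1, 4].
level_id_perm = np.array([2, 0, 3, 1, 4])
rois_in_level_order = np.array(['roi2', 'roi0', 'roi3', 'roi1', 'roi4'])
invert_perm = np.argsort(level_id_perm)    # equivalent of tf.invert_permutation
print(rois_in_level_order[invert_perm])    # ['roi0' 'roi1' 'roi2' 'roi3' 'roi4']
```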
if __name__ == '__main__':
"""
Demonstrate what's wrong with tf.image.crop_and_resize:
"""
import numpy as np
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
......
......@@ -19,22 +19,25 @@ from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_number
import tensorpack.utils.viz as tpviz
from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection
from basemodel import (
image_preprocess, pretrained_resnet_c4_backbone, resnet_conv5)
image_preprocess, resnet_c4_backbone, resnet_conv5,
resnet_fpn_backbone)
from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_head, maskrcnn_loss)
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_upXconv_head, maskrcnn_loss,
fpn_model, fastrcnn_2fc_head, multilevel_roi_align)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors)
get_all_anchors, get_all_anchors_fpn)
from viz import (
draw_annotation, draw_proposal_recall,
draw_predictions, draw_final_outputs)
......@@ -57,7 +60,16 @@ def get_model_output_names():
return ret
class Model(ModelDesc):
def get_model():
if config.MODE_FPN:
if get_tf_version_number() < 1.6:
logger.warn("FPN may crash in TF<1.6 due to a TF issue.")
return ResNetFPNModel()
else:
return ResNetC4Model()
class DetectionModel(ModelDesc):
def inputs(self):
ret = [
tf.placeholder(tf.float32, (None, None, 3), 'image'),
......@@ -71,13 +83,13 @@ class Model(ModelDesc):
) # NR_GT x height x width
return ret
def _preprocess(self, image):
def preprocess(self, image):
image = tf.expand_dims(image, 0)
image = image_preprocess(image, bgr=True)
return tf.transpose(image, [0, 3, 1, 2])
@under_name_scope()
def slice_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
def narrow_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
"""
Slice anchors/anchor_labels/anchor_boxes to the spatial size of this featuremap.
Args:
......@@ -93,49 +105,130 @@ class Model(ModelDesc):
anchor_boxes = tf.slice(anchor_boxes, [0, 0, 0, 0], slice4d)
return anchors, anchor_labels, anchor_boxes
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
tf.summary.scalar('learning_rate', lr)
factor = get_batch_factor()
if factor != 1:
lr = lr / float(factor)
opt = tf.train.MomentumOptimizer(lr, 0.9)
opt = optimizer.AccumGradOptimizer(opt, factor)
else:
opt = tf.train.MomentumOptimizer(lr, 0.9)
return opt
def fastrcnn_training(self, image,
rcnn_labels, fg_rcnn_boxes, gt_boxes_per_fg,
rcnn_label_logits, fg_rcnn_box_logits):
"""
Args:
image (NCHW):
rcnn_labels (n): labels for each sampled target
fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground target
gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground target
rcnn_label_logits (n): label logits for each sampled target
fg_rcnn_box_logits (fg x 4): box logits for each sampled foreground target
"""
with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize(
image, fg_rcnn_boxes,
tf.zeros(tf.shape(fg_rcnn_boxes)[0], dtype=tf.int32), 300)
fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1]) # BGR->RGB
tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
encoded_boxes = encode_bbox_target(
gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, rcnn_label_logits,
encoded_boxes,
fg_rcnn_box_logits)
return fastrcnn_label_loss, fastrcnn_box_loss
def fastrcnn_inference(self, image_shape2d,
rcnn_boxes, rcnn_label_logits, rcnn_box_logits):
"""
Args:
image_shape2d: h, w
rcnn_boxes (nx4): the proposal boxes
rcnn_label_logits (n):
rcnn_box_logits (nx4):
Returns:
boxes (mx4):
labels (m): each >= 1
"""
label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class
anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target(
rcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
final_probs = tf.identity(final_probs, 'final_probs')
final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
return final_boxes, final_labels
class ResNetC4Model(DetectionModel):
def build_graph(self, *inputs):
is_training = get_current_tower_context().is_training
if config.MODE_MASK:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
else:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self._preprocess(image) # 1CHW
image = self.preprocess(image) # 1CHW
featuremap = pretrained_resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
fm_anchors, anchor_labels, anchor_boxes = self.slice_to_featuremap(
fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
image_shape2d = tf.shape(image)[2:] # h,w
decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors) # fHxfWxNAx4, floatbox
pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors) # fHxfWxNAx4, floatbox
proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(decoded_boxes, [-1, 4]),
tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]),
image_shape2d)
image_shape2d,
config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)
if is_training:
# sample proposal boxes in training
rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
else:
# use all proposal boxes in inference
boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
# The boxes to be used to crop RoIs.
# Use all proposal boxes in inference
rcnn_boxes = proposal_boxes
boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
# which was fixed in TF 1.6
def ff_true():
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS)
# Return C5 feature to be shared with mask branch
return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits
def ff_false():
ncls = config.NUM_CLASS
return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])
if get_tf_version_number() >= 1.6:
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
else:
logger.warn("This example may drop support for TF < 1.6 soon.")
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)
......@@ -145,35 +238,27 @@ class Model(ModelDesc):
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
# fastrcnn loss
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize(
image, fg_sampled_boxes,
tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1]) # BGR->RGB
tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
encoded_boxes = encode_bbox_target(
matched_gt_boxes,
fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, fastrcnn_label_logits,
encoded_boxes,
tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))
fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
image, rcnn_labels, fg_sampled_boxes,
matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
if config.MODE_MASK:
# maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
# In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS) # #fg x #cat x 14x14
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', fg_feature, config.NUM_CLASS, num_convs=0) # #fg x #cat x 14x14
gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt) # nfg x H x W
matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt) # nfg x H x W
target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks_for_fg, 1),
tf.expand_dims(matched_gt_masks, 1),
fg_sampled_boxes,
tf.range(tf.size(fg_inds_wrt_gt)), 14,
pad_border=False) # nfg x 1x14x14
......@@ -195,53 +280,176 @@ class Model(ModelDesc):
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class
anchors = tf.tile(tf.expand_dims(proposal_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target(
fastrcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
final_probs = tf.identity(final_probs, 'final_probs')
final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK:
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def f1():
roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
mask_logits = maskrcnn_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS) # #result x #cat x 14x14
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
return tf.sigmoid(final_mask_logits)
final_masks = tf.cond(tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
final_masks = tf.cond(tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
tf.identity(final_masks, name='final_masks')
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
tf.summary.scalar('learning_rate', lr)
factor = get_batch_factor()
if factor != 1:
lr = lr / float(factor)
opt = tf.train.MomentumOptimizer(lr, 0.9)
opt = optimizer.AccumGradOptimizer(opt, factor)
class ResNetFPNModel(DetectionModel):
def inputs(self):
ret = [
tf.placeholder(tf.float32, (None, None, 3), 'image')]
num_anchors = len(config.ANCHOR_RATIOS)
for k in range(len(config.ANCHOR_STRIDES_FPN)):
ret.extend([
tf.placeholder(tf.int32, (None, None, num_anchors),
'anchor_labels_lvl{}'.format(k + 2)),
tf.placeholder(tf.float32, (None, None, num_anchors, 4),
'anchor_boxes_lvl{}'.format(k + 2))])
ret.extend([
tf.placeholder(tf.float32, (None, 4), 'gt_boxes'),
tf.placeholder(tf.int64, (None,), 'gt_labels')]) # all > 0
if config.MODE_MASK:
ret.append(
tf.placeholder(tf.uint8, (None, None, None), 'gt_masks')
) # NR_GT x height x width
return ret
def build_graph(self, *inputs):
num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
assert len(config.ANCHOR_SIZES) == num_fpn_level
is_training = get_current_tower_context().is_training
image = inputs[0]
input_anchors = inputs[1: 1 + 2 * num_fpn_level]
multilevel_anchor_labels = input_anchors[0::2]
multilevel_anchor_boxes = input_anchors[1::2]
gt_boxes, gt_labels = inputs[11], inputs[12]  # 11 == 1 + 2 * num_fpn_level
if config.MODE_MASK:
gt_masks = inputs[-1]
image = self.preprocess(image) # 1CHW
image_shape2d = tf.shape(image)[2:] # h,w
c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
p23456 = fpn_model('fpn', c2345)
# Multi-Level RPN Proposals
multilevel_proposals = []
rpn_loss_collection = []
for lvl in range(num_fpn_level):
rpn_label_logits, rpn_box_logits = rpn_head(
'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
anchors = tf.constant(get_all_anchors_fpn()[lvl], name='rpn_anchor_lvl{}'.format(lvl + 2))
anchors, anchor_labels, anchor_boxes = \
self.narrow_to_featuremap(p23456[lvl], anchors,
multilevel_anchor_labels[lvl],
multilevel_anchor_boxes[lvl])
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]),
image_shape2d,
config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
multilevel_proposals.append((proposal_boxes, proposal_scores))
if is_training:
label_loss, box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded,
rpn_label_logits, rpn_box_logits)
rpn_loss_collection.extend([label_loss, box_loss])
# Merge proposals from multi levels, pick top K
proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0) # nx4
proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0) # n
proposal_topk = tf.minimum(tf.size(proposal_scores),
config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
proposal_boxes = tf.gather(proposal_boxes, topk_indices)
if is_training:
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
else:
opt = tf.train.MomentumOptimizer(lr, 0.9)
return opt
# The boxes to be used to crop RoIs.
rcnn_boxes = proposal_boxes
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)
if is_training:
# rpn loss is already defined above
with tf.name_scope('rpn_losses'):
rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)
# fastrcnn loss:
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
image, rcnn_labels, fg_sampled_boxes,
matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
if config.MODE_MASK:
# maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], fg_sampled_boxes, 14)
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28
matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt) # fg x H x W
target_masks_for_fg = crop_and_resize(
tf.expand_dims(matched_gt_masks, 1),
fg_sampled_boxes,
tf.range(tf.size(fg_inds_wrt_gt)), 28,
pad_border=False) # fg x 1x28x28
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
else:
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost')
total_cost = tf.add_n(rpn_loss_collection + [
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss, wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK:
# Cascade inference needs roi transform with refined boxes.
roi_feature_maskrcnn = multilevel_roi_align(
p23456[:4], final_boxes, 14)
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28
tf.sigmoid(final_mask_logits, name='final_masks')
def visualize(model_path, nr_visualize=50, output_dir='output'):
assert not config.MODE_FPN, "FPN visualization is not supported yet!"
df = get_train_dataflow() # we don't visualize mask stuff
df.reset_state()
pred = OfflinePredictor(PredictConfig(
model=Model(),
model=ResNetC4Model(),
session_init=get_model_loader(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[
......@@ -320,7 +528,11 @@ class EvalCallback(Callback):
logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step))
with open(output_file, 'w') as f:
json.dump(all_results, f)
try:
scores = print_evaluation_scores(output_file)
except Exception:
logger.exception("Exception in COCO evaluation.")
scores = {}
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
......@@ -361,7 +573,7 @@ if __name__ == '__main__':
visualize(args.load)
else:
pred = OfflinePredictor(PredictConfig(
model=Model(),
model=get_model(),
session_init=get_model_loader(args.load),
input_names=['image'],
output_names=get_model_output_names()))
......@@ -387,12 +599,13 @@ if __name__ == '__main__':
(steps * factor // stepnum, config.BASE_LR * mult))
cfg = TrainConfig(
model=Model(),
model=get_model(),
data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)),
callbacks=[
PeriodicCallback(
ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),
every_k_epochs=20),
SessionRunTimeout(60000), # 1 minute timeout
# linear warmup
ScheduledHyperParamSetter(
'learning_rate', warmup_schedule, interp='linear', step_based=True),
......@@ -402,8 +615,8 @@ if __name__ == '__main__':
EstimatedTimeLeft(),
],
steps_per_epoch=stepnum,
max_epoch=config.LR_SCHEDULE[2] * factor // stepnum,
max_epoch=config.LR_SCHEDULE[-1] * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None,
)
trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu(), mode='cpu')
launch_train_with_config(cfg, trainer)