Commit 1fc18a6e authored by Yuxin Wu's avatar Yuxin Wu

FPN initial commit

parent 04a64849
......@@ -117,9 +117,18 @@ def resnet_conv5(image, num_block):
def pretrained_resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
shape2d = tf.shape(image)[2:]
mult = config.FPN_RESOLUTION_REQUIREMENT * 1.
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
pad_shape2d = new_shape2d - shape2d
assert len(num_blocks) == 4
# TODO pad 1 at each stage
with resnet_argscope():
l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]])
chan = image.shape[1]
l = tf.pad(image,
tf.stack([[0, 0], [0, 0],
[2, 3 + pad_shape2d[0]], [2, 3 + pad_shape2d[1]]]))
l.set_shape([None, chan, None, None])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
......
......@@ -4,11 +4,12 @@
import numpy as np
# mode flags ---------------------
MODE_MASK = True
MODE_MASK = False
# dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR'
TRAIN_DATASET = ['train2014', 'valminusminival2014']
# TRAIN_DATASET = ['valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on single dataset
NUM_CLASS = 81
CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader
......@@ -23,12 +24,12 @@ BASE_LR = 1e-2
WARMUP = 1000 # in steps
STEPS_PER_EPOCH = 500
LR_SCHEDULE = [150000, 230000, 280000]
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
# LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
SHORT_EDGE_SIZE = 800
MAX_SIZE = 1333
MAX_SIZE = 1333 # TODO use 1344
# alternative (worse & faster) setting: 600, 1024
# anchors -------------------------
......@@ -55,7 +56,7 @@ TRAIN_POST_NMS_TOPK = 2000
CROWD_OVERLAP_THRES = 0.7
# fastrcnn training ---------------------
FASTRCNN_BATCH_PER_IM = 256
FASTRCNN_BATCH_PER_IM = 512
FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
FASTRCNN_FG_THRESH = 0.5
# fg ratio in a ROI batch
......@@ -70,6 +71,9 @@ RESULT_SCORE_THRESH_VIS = 0.3 # only visualize confident results
RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE
MODE_FPN = False
MODE_FPN = True
FPN_NUM_CHANNEL = 256
FPN_SIZE_REQUIREMENT = 32
FASTRCNN_FC_HEAD_DIM = 1024
FPN_RESOLUTION_REQUIREMENT = 32
TRAIN_FPN_NMS_TOPK = 2048
TEST_FPN_NMS_TOPK = 1024
......@@ -51,7 +51,12 @@ def get_all_anchors(
# anchors are intbox here.
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
field_size = int(np.ceil(config.MAX_SIZE / stride))
max_size = config.MAX_SIZE
if config.MODE_FPN:
# TODO setting this in config is perhaps better
size_mult = config.FPN_RESOLUTION_REQUIREMENT * 1.
max_size = np.ceil(max_size / size_mult) * size_mult
field_size = int(np.ceil(max_size / stride))
shifts = np.arange(0, field_size) * stride
shift_x, shift_y = np.meshgrid(shifts, shifts)
shift_x = shift_x.flatten()
......@@ -337,7 +342,7 @@ def get_train_dataflow(add_mask=False):
return ret
ds = MapData(ds, preprocess)
ds = PrefetchDataZMQ(ds, 1)
ds = PrefetchDataZMQ(ds, 3)
return ds
......@@ -359,7 +364,6 @@ if __name__ == '__main__':
import os
from tensorpack.dataflow import PrintData
config.BASEDIR = os.path.expanduser('~/data/coco')
config.TRAIN_DATASET = ['train2014']
ds = get_train_dataflow(add_mask=config.MODE_MASK)
ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start()
......
......@@ -2,15 +2,16 @@
# File: model.py
import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
import numpy as np
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, MaxPooling,
Conv2D, FullyConnected, MaxPooling,
layer_register, Conv2DTranspose, FixedUnPooling)
from utils.box_ops import pairwise_iou
from utils.box_ops import area as tf_area
import config
......@@ -28,6 +29,7 @@ def clip_boxes(boxes, window, name=None):
@layer_register(log_shape=True)
@auto_reuse_variable_scope
def rpn_head(featuremap, channel, num_anchors):
"""
Returns:
......@@ -67,7 +69,8 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_mask = tf.stop_gradient(tf.not_equal(anchor_labels, -1))
pos_mask = tf.stop_gradient(tf.equal(anchor_labels, 1))
nr_valid = tf.stop_gradient(tf.count_nonzero(valid_mask, dtype=tf.int32), name='num_valid_anchor')
nr_pos = tf.count_nonzero(pos_mask, dtype=tf.int32, name='num_pos_anchor')
nr_pos = tf.identity(tf.count_nonzero(pos_mask, dtype=tf.int32), name='num_pos_anchor')
# nr_pos is guaranteed >0 in C4. But in FPN, even nr_valid could be 0.
valid_anchor_labels = tf.boolean_mask(anchor_labels, valid_mask)
valid_label_logits = tf.boolean_mask(label_logits, valid_mask)
......@@ -84,17 +87,20 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)),
dtype=tf.int32)
summaries.append(tf.truediv(
pos_prediction_corr,
nr_pos, name='recall_th{}'.format(th)))
placeholder = 0.5 # TODO A small value will make summaries appear lower.
recall = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos))
recall = tf.where(tf.equal(nr_pos, 0), placeholder, recall, name='recall_th{}'.format(th))
precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0.0, precision, name='precision_th{}'.format(th))
summaries.append(precision)
precision = tf.where(tf.equal(nr_pos_prediction, 0),
placeholder, precision, name='precision_th{}'.format(th))
summaries.extend([precision, recall])
add_moving_summary(*summaries)
placeholder = 1.
label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
label_loss = tf.reduce_mean(label_loss, name='label_loss')
label_loss = tf.reduce_mean(label_loss)
label_loss = tf.where(tf.equal(nr_valid, 0), placeholder, label_loss, name='label_loss')
pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask)
pos_box_logits = tf.boolean_mask(box_logits, pos_mask)
......@@ -104,7 +110,8 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
reduction=tf.losses.Reduction.SUM) / delta
box_loss = tf.div(
box_loss,
tf.cast(nr_valid, tf.float32), name='box_loss')
tf.cast(nr_valid, tf.float32))
box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss')
add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
return label_loss, box_loss
......@@ -167,26 +174,29 @@ def encode_bbox_target(boxes, anchors):
@under_name_scope()
def generate_rpn_proposals(boxes, scores, img_shape):
def generate_rpn_proposals(boxes, scores, img_shape,
pre_nms_topk, post_nms_topk=None):
"""
Sample RPN proposals by the following steps:
1. Pick top k1 by scores
2. NMS them
3. Pick top k2 by scores. Default k2 == k1, i.e. does not filter the NMS output.
Args:
boxes: nx4 float dtype, decoded to floatbox already
boxes: nx4 float dtype, the proposal boxes. Decoded to floatbox already
scores: n float, the logits
img_shape: [h, w]
pre_nms_topk, post_nms_topk (int): See above.
Returns:
boxes: kx4 float
scores: k logits
"""
assert boxes.shape.ndims == 2, boxes.shape
if get_current_tower_context().is_training:
PRE_NMS_TOPK = config.TRAIN_PRE_NMS_TOPK
POST_NMS_TOPK = config.TRAIN_POST_NMS_TOPK
else:
PRE_NMS_TOPK = config.TEST_PRE_NMS_TOPK
POST_NMS_TOPK = config.TEST_POST_NMS_TOPK
topk = tf.minimum(PRE_NMS_TOPK, tf.size(scores))
if post_nms_topk is None:
post_nms_topk = pre_nms_topk
topk = tf.minimum(pre_nms_topk, tf.size(scores))
topk_scores, topk_indices = tf.nn.top_k(scores, k=topk, sorted=False)
topk_boxes = tf.gather(boxes, topk_indices)
topk_boxes = clip_boxes(topk_boxes, img_shape)
......@@ -199,22 +209,21 @@ def generate_rpn_proposals(boxes, scores, img_shape):
topk_valid_boxes_x1y1x2y2 = tf.boolean_mask(topk_boxes_x1y1x2y2, valid)
topk_valid_scores = tf.boolean_mask(topk_scores, valid)
# TODO not needed
topk_valid_boxes_y1x1y2x2 = tf.reshape(
tf.reverse(topk_valid_boxes_x1y1x2y2, axis=[2]),
(-1, 4), name='nms_input_boxes')
nms_indices = tf.image.non_max_suppression(
topk_valid_boxes_y1x1y2x2,
topk_valid_scores,
max_output_size=POST_NMS_TOPK,
max_output_size=post_nms_topk,
iou_threshold=config.RPN_PROPOSAL_NMS_THRESH)
topk_valid_boxes = tf.reshape(topk_valid_boxes_x1y1x2y2, (-1, 4))
final_boxes = tf.gather(
topk_valid_boxes,
nms_indices, name='boxes')
final_scores = tf.gather(topk_valid_scores, nms_indices, name='scores')
final_boxes = tf.gather(topk_valid_boxes, nms_indices)
final_scores = tf.gather(topk_valid_scores, nms_indices)
tf.sigmoid(final_scores, name='probs') # for visualization
return final_boxes, final_scores
return tf.stop_gradient(final_boxes, name='boxes'), tf.stop_gradient(final_scores, name='scores')
@under_name_scope()
......@@ -243,6 +252,7 @@ def proposal_metrics(iou):
def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
"""
Sample some ROIs from all proposals for training.
#fg is guaranteed to be > 0, because ground truth boxes are added as RoIs.
Args:
boxes: nx4 region proposals, floatbox
......@@ -288,13 +298,15 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
fg_inds_wrt_gt = tf.gather(best_iou_ind, fg_inds) # num_fg
all_indices = tf.concat([fg_inds, bg_inds], axis=0) # indices w.r.t all n+m proposal boxes
ret_boxes = tf.gather(boxes, all_indices, name='sampled_proposal_boxes')
ret_boxes = tf.gather(boxes, all_indices)
ret_labels = tf.concat(
[tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0)
# stop the gradient -- they are meant to be ground-truth
return tf.stop_gradient(ret_boxes), tf.stop_gradient(ret_labels), fg_inds_wrt_gt
return tf.stop_gradient(ret_boxes, name='sampled_proposal_boxes'), \
tf.stop_gradient(ret_labels, name='sampled_labels'), \
tf.stop_gradient(fg_inds_wrt_gt)
@under_name_scope()
......@@ -381,16 +393,15 @@ def roi_align(featuremap, boxes, output_shape):
@layer_register(log_shape=True)
def fastrcnn_head(feature, num_classes):
def fastrcnn_outputs(feature, num_classes):
"""
Args:
feature (NxCx7x7):
feature (any shape):
num_classes(int): num_category + 1
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
feature = GlobalAvgPooling('gap', feature, data_format='channels_first')
classification = FullyConnected(
'class', feature, num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=0.01))
......@@ -555,13 +566,21 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
return loss
@layer_register(log_shape=True)
def fpn_model(features):
"""
Args:
features ([tf.Tensor]): ResNet features c2-c5
Returns:
[tf.Tensor]: FPN features p2-p6
"""
assert len(features) == 4, features
num_channel = config.FPN_NUM_CHANNEL
def upsample2x(x):
def upsample2x(name, x):
# TODO may not be optimal in speed or math
return FixedUnPooling(x, 2, data_format='channels_first')
return FixedUnPooling(name, x, 2, data_format='channels_first')
with argscope(Conv2D, data_format='channels_first',
nl=tf.identity, use_bias=True,
......@@ -573,19 +592,66 @@ def fpn_model(features):
if idx == 0:
lat_sum_5432.append(lat)
else:
lat = lat + upsample2x(lat_sum_5432[-1])
lat = lat + upsample2x('upsample_c{}'.format(5 - idx), lat_sum_5432[-1])
lat_sum_5432.append(lat)
p2345 = [Conv2D('fpn_3x3_p{}'.format(i + 2), c, num_channel, 3)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])]
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2)
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first')
return p2345 + [p6]
@under_name_scope()
def fpn_map_rois_to_levels(boxes):
    """
    Split RoIs into four buckets (FPN level 2 to 5) by the square root of their area.

    The level of a box is ``floor(4 + log2(sqrt(area) / 224 + 1e-6))``;
    anything at or below 2 goes to the first bucket and anything at or
    above 5 goes to the last one.

    Args:
        boxes (nx4)

    Returns:
        [tf.Tensor]: 4 tensors for level 2-5. Each tensor is a vector of indices of boxes in its level.
        [tf.Tensor]: 4 tensors, the gathered boxes in each level.

    Be careful that the returned tensor could be empty.
    """
    sqrtarea = tf.sqrt(tf_area(boxes))
    relative_scale = sqrtarea * (1. / 224) + 1e-6
    log2_scale = tf.log(relative_scale) * (1.0 / np.log(2))
    level = tf.floor(4 + log2_scale)

    # RoI levels range from 2~5 (not 6)
    level_masks = [
        level <= 2,
        tf.equal(level, 3),   # == is not supported
        tf.equal(level, 4),
        level >= 5]

    level_ids = []
    for lvl, mask in enumerate(level_masks, 2):
        # tf.where yields one row per selected box; flatten it to a vector of indices.
        level_ids.append(
            tf.reshape(tf.where(mask), [-1], name='roi_level{}_id'.format(lvl)))

    num_in_levels = []
    for lvl, ids in enumerate(level_ids, 2):
        num_in_levels.append(tf.size(ids, name='num_roi_level{}'.format(lvl)))
    add_moving_summary(*num_in_levels)

    level_boxes = []
    for ids in level_ids:
        level_boxes.append(tf.gather(boxes, ids))
    return level_ids, level_boxes
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, dim, num_classes):
    """
    Fast R-CNN head made of two ReLU fully-connected layers ('fc6', 'fc7')
    followed by the 'outputs' layers built by `fastrcnn_outputs`.

    Args:
        feature (any shape):
        dim (int): mlp dim
        num_classes(int): num_category + 1

    Returns:
        cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
    """
    # Both hidden layers share one initializer object and the same settings.
    fc_kwargs = dict(
        kernel_initializer=tf.random_normal_initializer(stddev=0.01),
        nl=tf.nn.relu)
    for layer_name in ('fc6', 'fc7'):
        feature = FullyConnected(layer_name, feature, dim, **fc_kwargs)
    return fastrcnn_outputs('outputs', feature, num_classes)
if __name__ == '__main__':
"""
Demonstrate what's wrong with tf.image.crop_and_resize:
"""
import numpy as np
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
......
......@@ -25,16 +25,18 @@ from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection
from basemodel import (
image_preprocess, pretrained_resnet_c4_backbone, resnet_conv5)
image_preprocess, pretrained_resnet_c4_backbone, resnet_conv5,
pretrained_resnet_fpn_backbone)
from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_head, maskrcnn_loss)
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_head, maskrcnn_loss,
fpn_model, fpn_map_rois_to_levels, fastrcnn_2fc_head)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors)
get_all_anchors, get_all_anchors_fpn)
from viz import (
draw_annotation, draw_proposal_recall,
draw_predictions, draw_final_outputs)
......@@ -57,7 +59,14 @@ def get_model_output_names():
return ret
class Model(ModelDesc):
def get_model():
    """
    Instantiate the detection model selected by ``config.MODE_FPN``:
    `ResNetFPNModel` when it is truthy, `ResNetC4Model` otherwise.
    """
    model_cls = ResNetFPNModel if config.MODE_FPN else ResNetC4Model
    return model_cls()
class DetectionModel(ModelDesc):
def inputs(self):
ret = [
tf.placeholder(tf.float32, (None, None, 3), 'image'),
......@@ -71,13 +80,13 @@ class Model(ModelDesc):
) # NR_GT x height x width
return ret
def _preprocess(self, image):
def preprocess(self, image):
image = tf.expand_dims(image, 0)
image = image_preprocess(image, bgr=True)
return tf.transpose(image, [0, 3, 1, 2])
@under_name_scope()
def slice_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
def narrow_to_featuremap(self, featuremap, anchors, anchor_labels, anchor_boxes):
"""
Args:
Slice anchors/anchor_labels/anchor_boxes to the spatial size of this featuremap.
......@@ -93,43 +102,119 @@ class Model(ModelDesc):
anchor_boxes = tf.slice(anchor_boxes, [0, 0, 0, 0], slice4d)
return anchors, anchor_labels, anchor_boxes
def optimizer(self):
    """
    Build the momentum (0.9) optimizer driven by the non-trainable
    'learning_rate' variable, which is also exported as a scalar summary.

    When `get_batch_factor()` is not 1, the learning rate is divided by that
    factor and the optimizer is wrapped in `AccumGradOptimizer` with the
    same factor.
    """
    lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
    tf.summary.scalar('learning_rate', lr)

    factor = get_batch_factor()
    if factor == 1:
        return tf.train.MomentumOptimizer(lr, 0.9)

    scaled_lr = lr / float(factor)
    base_opt = tf.train.MomentumOptimizer(scaled_lr, 0.9)
    return optimizer.AccumGradOptimizer(base_opt, factor)
def fastrcnn_training(self, image,
                      rcnn_labels, fg_rcnn_boxes, gt_boxes_per_fg,
                      rcnn_label_logits, fg_rcnn_box_logits):
    """
    Compute the Fast R-CNN losses and add an image summary of the
    foreground patches cropped from the input image.

    Args:
        image (NCHW):
        rcnn_labels (n): labels for each sampled targets
        fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground targets
        gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground targets
        rcnn_label_logits: label logits for each sampled targets
            (passed unchanged to `fastrcnn_losses`)
        fg_rcnn_box_logits: box logits for each sampled foreground targets
            (passed unchanged to `fastrcnn_losses`)

    Returns:
        (label loss, box loss) as produced by `fastrcnn_losses`.
    """
    with tf.name_scope('fg_sample_patch_viz'):
        num_fg = tf.shape(fg_rcnn_boxes)[0]
        # Every foreground box is cropped from batch index 0.
        box_ind = tf.zeros(num_fg, dtype=tf.int32)
        patches = crop_and_resize(image, fg_rcnn_boxes, box_ind, 300)
        patches = tf.transpose(patches, [0, 2, 3, 1])   # move axis 1 to the end
        patches = tf.reverse(patches, axis=[-1])  # BGR->RGB
        tf.summary.image('viz', patches, max_outputs=30)

    # Regression targets: encoded gt boxes, scaled by the per-coordinate weights.
    encoded_boxes = encode_bbox_target(gt_boxes_per_fg, fg_rcnn_boxes)
    encoded_boxes = encoded_boxes * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)

    label_loss, box_loss = fastrcnn_losses(
        rcnn_labels, rcnn_label_logits,
        encoded_boxes,
        fg_rcnn_box_logits)
    return label_loss, box_loss
def fastrcnn_inference(self, image_shape2d,
                       rcnn_boxes, rcnn_label_logits, rcnn_box_logits):
    """
    Turn the Fast R-CNN head outputs into final detections.

    Creates the named tensors 'fastrcnn_all_probs', 'fastrcnn_all_boxes',
    'final_probs', 'final_boxes' and 'final_labels'.

    Args:
        image_shape2d: h, w
        rcnn_boxes (nx4): the proposal boxes
        rcnn_label_logits: per-proposal class logits
            (presumably n x #class, given the softmax below -- TODO confirm)
        rcnn_box_logits: per-proposal box logits
            (presumably n x #category x 4 so that they line up with the
            tiled proposals below -- TODO confirm)

    Returns:
        boxes (mx4):
        labels (m): each >= 1
    """
    label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs')    # #proposal x #Class

    # Repeat every proposal once per category: #proposal x #Cat x 4
    num_category = config.NUM_CLASS - 1
    proposals_per_cat = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, num_category, 1])

    # Undo the regression weights before decoding the boxes.
    unweighted_logits = rcnn_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
    decoded_boxes = decode_bbox_target(unweighted_logits, proposals_per_cat)
    decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

    # indices: Nx2. Each index into (#proposal, #category)
    pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
    final_probs = tf.identity(final_probs, 'final_probs')
    final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
    # Category indices start at 0, so shift by one to get labels >= 1.
    final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
    return final_boxes, final_labels
class ResNetC4Model(DetectionModel):
def build_graph(self, *inputs):
is_training = get_current_tower_context().is_training
if config.MODE_MASK:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
else:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self._preprocess(image) # 1CHW
image = self.preprocess(image) # 1CHW
featuremap = pretrained_resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
fm_anchors, anchor_labels, anchor_boxes = self.slice_to_featuremap(
fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
image_shape2d = tf.shape(image)[2:] # h,w
decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors) # fHxfWxNAx4, floatbox
pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors) # fHxfWxNAx4, floatbox
proposal_boxes, proposal_scores = generate_rpn_proposals(
tf.reshape(decoded_boxes, [-1, 4]),
tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]),
image_shape2d)
image_shape2d,
config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)
if is_training:
# sample proposal boxes in training
rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
else:
# use all proposal boxes in inference
boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
# The boxes to be used to crop RoIs.
# Use all proposal boxes in inference
rcnn_boxes = proposal_boxes
boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def ff_true():
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS)
# Return C5 feature to be shared with mask branch
return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits
def ff_false():
......@@ -145,29 +230,20 @@ class Model(ModelDesc):
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
# fastrcnn loss
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize(
image, fg_sampled_boxes,
tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1]) # BGR->RGB
tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
encoded_boxes = encode_bbox_target(
matched_gt_boxes,
fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, fastrcnn_label_logits,
encoded_boxes,
tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))
fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
image, rcnn_labels, fg_sampled_boxes,
matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
if config.MODE_MASK:
# maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
# In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS) # #fg x #cat x 14x14
......@@ -195,18 +271,8 @@ class Model(ModelDesc):
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class
anchors = tf.tile(tf.expand_dims(proposal_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target(
fastrcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
final_probs = tf.identity(final_probs, 'final_probs')
final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK:
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
......@@ -219,21 +285,137 @@ class Model(ModelDesc):
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
return tf.sigmoid(final_mask_logits)
final_masks = tf.cond(tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
final_masks = tf.cond(tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
tf.identity(final_masks, name='final_masks')
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
tf.summary.scalar('learning_rate', lr)
factor = get_batch_factor()
if factor != 1:
lr = lr / float(factor)
opt = tf.train.MomentumOptimizer(lr, 0.9)
opt = optimizer.AccumGradOptimizer(opt, factor)
class ResNetFPNModel(DetectionModel):
def inputs(self):
    """
    Declare the input placeholders of the FPN model.

    Order: 'image', then for each entry of ``config.ANCHOR_STRIDES_FPN``
    (levels numbered from 2) the pair 'anchor_labels_lvl{k}' /
    'anchor_boxes_lvl{k}', then 'gt_boxes' and 'gt_labels', and finally
    'gt_masks' when ``config.MODE_MASK`` is set.
    """
    num_anchors = len(config.ANCHOR_RATIOS)
    placeholders = [tf.placeholder(tf.float32, (None, None, 3), 'image')]

    for lvl in range(2, 2 + len(config.ANCHOR_STRIDES_FPN)):
        placeholders.append(
            tf.placeholder(tf.int32, (None, None, num_anchors),
                           'anchor_labels_lvl{}'.format(lvl)))
        placeholders.append(
            tf.placeholder(tf.float32, (None, None, num_anchors, 4),
                           'anchor_boxes_lvl{}'.format(lvl)))

    placeholders.append(tf.placeholder(tf.float32, (None, 4), 'gt_boxes'))
    placeholders.append(tf.placeholder(tf.int64, (None,), 'gt_labels'))  # all > 0

    if config.MODE_MASK:
        # NR_GT x height x width
        placeholders.append(
            tf.placeholder(tf.uint8, (None, None, None), 'gt_masks'))
    return placeholders
def build_graph(self, *inputs):
    """
    Build the Faster R-CNN + FPN graph for the current tower.

    Args:
        inputs: the tensors declared by `inputs()`, in the same order:
            image, then (anchor_labels, anchor_boxes) for every FPN level,
            then gt_boxes, gt_labels and, when ``config.MODE_MASK`` is set,
            gt_masks.

    Returns:
        The 'total_cost' tensor when the tower is in training mode.
        In inference mode nothing is returned; the outputs are the tensors
        created by `fastrcnn_inference`.
    """
    num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
    assert len(config.ANCHOR_SIZES) == num_fpn_level
    is_training = get_current_tower_context().is_training

    # Input layout: [image, (labels, boxes) * num_fpn_level, gt_boxes, gt_labels, (gt_masks)].
    # The gt positions are derived from the number of levels instead of being
    # hard-coded, so they stay in sync with `inputs()`.
    image = inputs[0]
    num_anchor_inputs = 2 * num_fpn_level
    input_anchors = inputs[1: 1 + num_anchor_inputs]
    multilevel_anchor_labels = input_anchors[0::2]
    multilevel_anchor_boxes = input_anchors[1::2]
    gt_boxes, gt_labels = inputs[1 + num_anchor_inputs], inputs[2 + num_anchor_inputs]
    if config.MODE_MASK:
        gt_masks = inputs[-1]   # not consumed below: no mask branch is built here

    image = self.preprocess(image)     # 1CHW
    image_shape2d = tf.shape(image)[2:]     # h,w

    c2345 = pretrained_resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
    p23456 = fpn_model('fpn', c2345)

    # Multi-Level RPN Proposals
    multilevel_anchors = get_all_anchors_fpn()
    assert len(multilevel_anchors) == num_fpn_level
    multilevel_proposals = []
    rpn_loss_collection = []
    for lvl in range(num_fpn_level):
        # The same 'rpn' layer name is used for every level.
        rpn_label_logits, rpn_box_logits = rpn_head(
            'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
        with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
            anchors, anchor_labels, anchor_boxes = \
                self.narrow_to_featuremap(p23456[lvl], multilevel_anchors[lvl],
                                          multilevel_anchor_labels[lvl],
                                          multilevel_anchor_boxes[lvl])
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
            pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
            proposal_boxes, proposal_scores = generate_rpn_proposals(
                tf.reshape(pred_boxes_decoded, [-1, 4]),
                tf.reshape(rpn_label_logits, [-1]),
                image_shape2d,
                config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
            multilevel_proposals.append((proposal_boxes, proposal_scores))
            if is_training:
                label_loss, box_loss = rpn_losses(
                    anchor_labels, anchor_boxes_encoded,
                    rpn_label_logits, rpn_box_logits)
                rpn_loss_collection.extend([label_loss, box_loss])

    # merge proposals from multi levels
    proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
    proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
    proposal_topk = tf.minimum(tf.size(proposal_scores),
                               config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
    proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
    proposal_boxes = tf.gather(proposal_boxes, topk_indices)

    if is_training:
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        rcnn_boxes = proposal_boxes

    # Reassign rcnn_boxes to levels
    level_ids, level_boxes = fpn_map_rois_to_levels(rcnn_boxes)
    all_rois = []
    # Crop patches from corresponding levels
    # (only the first four pyramid levels are used for RoI cropping)
    for i, (boxes, featuremap) in enumerate(zip(level_boxes, p23456[:4])):
        with tf.name_scope('roi_level{}'.format(i + 2)):
            boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
            all_rois.append(roi_align(featuremap, boxes_on_featuremap, 7))
    all_rois = tf.concat(all_rois, axis=0)  # NCHW
    # Unshuffle to the original order, to match the original samples
    level_id_perm = tf.concat(level_ids, axis=0)  # A permutation of 1~N
    level_id_invert_perm = tf.invert_permutation(level_id_perm)
    all_rois = tf.gather(all_rois, level_id_invert_perm)

    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
        'fastrcnn', all_rois, config.FASTRCNN_FC_HEAD_DIM, config.NUM_CLASS)

    if is_training:
        # fastrcnn loss:
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)

        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        # Constant stand-in for the mask loss: no mask branch is built here.
        mrcnn_loss = 0.0

        # NOTE(review): the pattern has no 'fpn' alternative, so variables
        # created under fpn_model('fpn', ...) are presumably not regularized
        # -- confirm this is intended.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n(rpn_loss_collection + [
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')

        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
def visualize(model_path, nr_visualize=50, output_dir='output'):
......@@ -241,7 +423,7 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
df.reset_state()
pred = OfflinePredictor(PredictConfig(
model=Model(),
model=ResNetC4Model(),
session_init=get_model_loader(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[
......@@ -361,7 +543,7 @@ if __name__ == '__main__':
visualize(args.load)
else:
pred = OfflinePredictor(PredictConfig(
model=Model(),
model=get_model(),
session_init=get_model_loader(args.load),
input_names=['image'],
output_names=get_model_output_names()))
......@@ -372,7 +554,7 @@ if __name__ == '__main__':
COCODetection(config.BASEDIR, 'val2014') # Only to load the class names into caches
predict(pred, args.predict)
else:
logger.set_logger_dir(args.logdir)
logger.set_logger_dir(args.logdir, 'd')
print_config()
factor = get_batch_factor()
stepnum = config.STEPS_PER_EPOCH
......@@ -387,7 +569,7 @@ if __name__ == '__main__':
(steps * factor // stepnum, config.BASE_LR * mult))
cfg = TrainConfig(
model=Model(),
model=get_model(),
data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)),
callbacks=[
PeriodicCallback(
......@@ -402,7 +584,7 @@ if __name__ == '__main__':
EstimatedTimeLeft(),
],
steps_per_epoch=stepnum,
max_epoch=config.LR_SCHEDULE[2] * factor // stepnum,
max_epoch=config.LR_SCHEDULE[-1] * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None,
)
trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment