Commit 940d4167 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] remove cfg.DATA.NUM_CLASS

parent 8ab0d4a6
......@@ -216,7 +216,6 @@ def finalize_configs(is_training):
Run some sanity checks, and populate some configs from others
"""
_C.freeze(False) # populate new keys now
_C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background
_C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well
_C.DATA.VAL = (_C.DATA.VAL, )
......
......@@ -137,7 +137,7 @@ class ResNetC4Model(GeneralizedRCNN):
feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCKS[-1]) # nxcx7x7
# Keep C5 feature to be shared with mask branch
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CATEGORY)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits, gt_boxes,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
......@@ -254,7 +254,7 @@ class ResNetFPNModel(GeneralizedRCNN):
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CATEGORY)
fastrcnn_head = FastRCNNHead(proposals, fastrcnn_box_logits, fastrcnn_label_logits,
gt_boxes, tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32))
else:
......@@ -263,7 +263,7 @@ class ResNetFPNModel(GeneralizedRCNN):
fastrcnn_head = CascadeRCNNHead(
proposals, roi_func, fastrcnn_head_func,
(gt_boxes, gt_labels), image_shape2d, cfg.DATA.NUM_CLASS)
(gt_boxes, gt_labels), image_shape2d, cfg.DATA.NUM_CATEGORY)
if self.training:
all_losses = fastrcnn_head.losses()
......
......@@ -10,7 +10,8 @@ from utils.box_ops import pairwise_iou
class CascadeRCNNHead(object):
def __init__(self, proposals,
roi_func, fastrcnn_head_func, gt_targets, image_shape2d, num_classes):
roi_func, fastrcnn_head_func, gt_targets, image_shape2d,
num_categories):
"""
Args:
proposals: BoxProposals
......@@ -66,7 +67,7 @@ class CascadeRCNNHead(object):
pooled_feature = self.scale_gradient(pooled_feature)
head_feature = self.fastrcnn_head_func('head', pooled_feature)
label_logits, box_logits = fastrcnn_outputs(
'outputs', head_feature, self.num_classes, class_agnostic_regression=True)
'outputs', head_feature, self.num_categories, class_agnostic_regression=True)
head = FastRCNNHead(proposals, box_logits, label_logits, self.gt_boxes, reg_weights)
refined_boxes = head.decoded_output_boxes_class_agnostic()
......@@ -107,7 +108,7 @@ class CascadeRCNNHead(object):
"""
ret = self._cascade_boxes[-1]
ret = tf.expand_dims(ret, 1) # class-agnostic
return tf.tile(ret, [1, self.num_classes, 1])
return tf.tile(ret, [1, self.num_categories + 1, 1])
def output_scores(self, name=None):
"""
......
......@@ -102,17 +102,18 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
@layer_register(log_shape=True)
def fastrcnn_outputs(feature, num_classes, class_agnostic_regression=False):
def fastrcnn_outputs(feature, num_categories, class_agnostic_regression=False):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
num_categories (int):
class_agnostic_regression (bool): if True, regression to N x 1 x 4
Returns:
cls_logits: N x num_class classification logits
reg_logits: N x num_classx4 or Nx2x4 if class agnostic
"""
num_classes = num_categories + 1
classification = FullyConnected(
'class', feature, num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=0.01))
......@@ -186,8 +187,7 @@ def fastrcnn_predictions(boxes, scores):
scores: K
labels: K
"""
assert boxes.shape[1] == cfg.DATA.NUM_CLASS
assert scores.shape[1] == cfg.DATA.NUM_CLASS
assert boxes.shape[1] == scores.shape[1]
boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4
scores = tf.transpose(scores[:, 1:], [1, 0]) # #catxn
......@@ -353,6 +353,7 @@ class FastRCNNHead(object):
if k != 'self' and v is not None:
setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
self._num_classes = box_logits.shape[1]
@memoized_method
def fg_box_logits(self):
......@@ -373,7 +374,7 @@ class FastRCNNHead(object):
def decoded_output_boxes(self):
""" Returns: N x #class x 4 """
anchors = tf.tile(tf.expand_dims(self.proposals.boxes, 1),
[1, cfg.DATA.NUM_CLASS, 1]) # N x #class x 4
[1, self._num_classes, 1]) # N x #class x 4
decoded_boxes = decode_bbox_target(
self.box_logits / self.bbox_regression_weights,
anchors
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment