Commit cf97218c authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] move fastrcnn outputs out of head. support class-agnostic regression

parent 0d36de5f
......@@ -100,22 +100,25 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
@layer_register(log_shape=True)
def fastrcnn_outputs(feature, num_classes):
def fastrcnn_outputs(feature, num_classes, class_agnostic_regression=False):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
class_agnostic_regression (bool): if True, regression to N x 1 x 4
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class x 4)
cls_logits: N x num_class classification logits
reg_logits: N x num_classx4 or Nx2x4 if class agnostic
"""
classification = FullyConnected(
'class', feature, num_classes,
kernel_initializer=tf.random_normal_initializer(stddev=0.01))
num_classes_for_box = 1 if class_agnostic_regression else num_classes
box_regression = FullyConnected(
'box', feature, num_classes * 4,
'box', feature, num_classes_for_box * 4,
kernel_initializer=tf.random_normal_initializer(stddev=0.001))
box_regression = tf.reshape(box_regression, (-1, num_classes, 4), name='output_box')
box_regression = tf.reshape(box_regression, (-1, num_classes_for_box, 4), name='output_box')
return classification, box_regression
......@@ -126,7 +129,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
labels: n,
label_logits: nxC
fg_boxes: nfgx4, encoded
fg_box_logits: nfgxCx4
fg_box_logits: nfgxCx4 or nfgx1x4 if class agnostic
Returns:
label_loss, box_loss
......@@ -138,9 +141,12 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds = tf.where(labels > 0)[:, 0]
fg_labels = tf.gather(labels, fg_inds)
num_fg = tf.size(fg_inds, out_type=tf.int64)
indices = tf.stack(
[tf.range(num_fg), fg_labels], axis=1) # #fgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices)
if int(fg_box_logits.shape[1]) > 1:
indices = tf.stack(
[tf.range(num_fg), fg_labels], axis=1) # #fgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices)
else:
fg_box_logits = tf.reshape(fg_box_logits, [-1, 4])
with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
prediction = tf.argmax(label_logits, axis=1, name='label_prediction')
......@@ -229,24 +235,23 @@ FastRCNN heads for FPN:
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, num_classes):
def fastrcnn_2fc_head(feature):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
Returns:
outputs of `fastrcnn_outputs()`
2D head feature
"""
dim = cfg.FPN.FRCNN_FC_HEAD_DIM
init = tf.variance_scaling_initializer()
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, activation=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, activation=tf.nn.relu)
return fastrcnn_outputs('outputs', hidden, num_classes)
return hidden
@layer_register(log_shape=True)
def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs, norm=None):
def fastrcnn_Xconv1fc_head(feature, num_convs, norm=None):
"""
Args:
feature (NCHW):
......@@ -255,7 +260,7 @@ def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs, norm=None):
norm (str or None): either None or 'GN'
Returns:
outputs of `fastrcnn_outputs()`
2D head feature
"""
assert norm in [None, 'GN'], norm
l = feature
......@@ -268,7 +273,7 @@ def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs, norm=None):
l = GroupNorm('gn{}'.format(k), l)
l = FullyConnected('fc', l, cfg.FPN.FRCNN_FC_HEAD_DIM,
kernel_initializer=tf.variance_scaling_initializer(), activation=tf.nn.relu)
return fastrcnn_outputs('outputs', l, num_classes)
return l
def fastrcnn_4conv1fc_head(*args, **kwargs):
......@@ -288,10 +293,10 @@ class FastRCNNHead(object):
"""
Args:
input_boxes: Nx4, inputs to the head
box_logits: Nx#classx4, the output of the head
box_logits: Nx#classx4 or Nx1x4, the output of the head
label_logits: Nx#class, the output of the head
bbox_regression_weights: a 4 element tensor
labels: N, each in [0, #class-1], the true label for each input box
labels: N, each in [0, #class), the true label for each input box
matched_gt_boxes_per_fg: #fgx4, the matching gt boxes for each fg input box
The last two arguments could be None when not training.
......@@ -299,10 +304,12 @@ class FastRCNNHead(object):
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
@memoized
def fg_inds_in_inputs(self):
""" Returns: #fg indices in [0, N-1] """
assert self.labels is not None
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds_in_inputs')
@memoized
......@@ -312,7 +319,7 @@ class FastRCNNHead(object):
@memoized
def fg_box_logits(self):
""" Returns: #fg x #class x 4 """
""" Returns: #fg x ? x 4 """
return tf.gather(self.box_logits, self.fg_inds_in_inputs(), name='fg_box_logits')
@memoized
......@@ -344,9 +351,20 @@ class FastRCNNHead(object):
@memoized
def decoded_output_boxes_for_true_label(self):
""" Returns: Nx4 decoded boxes """
assert self.labels is not None
return self._decoded_output_boxes_for_label(self.labels)
@memoized
def decoded_output_boxes_for_predicted_label(self):
""" Returns: Nx4 decoded boxes """
return self._decoded_output_boxes_for_label(self.predicted_labels())
@memoized
def decoded_output_boxes_for_label(self, labels):
assert not self._bbox_class_agnostic
indices = tf.stack([
tf.range(tf.size(self.labels, out_type=tf.int64)),
self.labels
labels
])
needed_logits = tf.gather_nd(self.box_logits, indices)
decoded = decode_bbox_target(
......@@ -355,7 +373,22 @@ class FastRCNNHead(object):
)
return decoded
@memoized
def decoded_output_boxes_class_agnostic(self):
assert self._bbox_class_agnostic
box_logits = tf.reshape(self.box_logits, [-1, 4])
decoded = decode_bbox_target(
box_logits / self.bbox_regression_weights,
self.input_boxes
)
return decoded
@memoized
def output_scores(self, name=None):
""" Returns: N x #class scores, summed to one for each box."""
return tf.nn.softmax(self.label_logits, name=name)
@memoized
def predicted_labels(self):
""" Returns: N ints """
return tf.argmax(self.label_logits, axis=1, name='predicted_labels')
......@@ -290,8 +290,9 @@ class ResNetFPNModel(DetectionModel):
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
fastrcnn_head_func = getattr(model_frcnn, cfg.FPN.FRCNN_HEAD_FUNC)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
head_feature = fastrcnn_head_func('fastrcnn', roi_feature_fastrcnn)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
'fastrcnn/outputs', head_feature, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes)
......
......@@ -221,16 +221,16 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
if sync_statistics == 'nccl':
if six.PY3 and TF_version <= (1, 9) and ctx.is_main_training_tower:
logger.warn("A bug in TensorFlow<=1.9 will cause cross-GPU BatchNorm to fail. "
"Upgrade or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360")
from tensorflow.contrib.nccl.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
num_dev = ctx.total
if num_dev == 1:
logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
else:
assert six.PY2 or TF_version >= (1, 10), \
"Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
"Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
from tensorflow.contrib.nccl.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
batch_mean = gen_nccl_ops.nccl_all_reduce(
input=batch_mean,
reduction='sum',
......@@ -243,13 +243,14 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
elif sync_statistics == 'horovod':
# Require https://github.com/uber/horovod/pull/331
import horovod
hvd_version = tuple(map(int, horovod.__version__.split('.')))
assert hvd_version >= (0, 13, 6), "sync_statistics needs horovod>=0.13.6 !"
import horovod.tensorflow as hvd
if hvd.size() == 1:
logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
else:
import horovod
hvd_version = tuple(map(int, horovod.__version__.split('.')))
assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
batch_mean = hvd.allreduce(batch_mean, average=True)
batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
batch_var = batch_mean_square - tf.square(batch_mean)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment