Commit b58b3a78 authored by Yuxin Wu

[MaskRCNN] Add GN in backbone

parent 843990f5
......@@ -3,11 +3,11 @@
from contextlib import contextmanager
import tensorflow as tf
from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.varreplace import custom_getter_scope
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, BNReLU, layer_register)
Conv2D, MaxPooling, BatchNorm, layer_register)
from config import config as cfg
......@@ -56,9 +56,13 @@ def maybe_reverse_pad(topleft, bottomright):
@contextmanager
def resnet_argscope():
def backbone_argscope():
def nonlin(x):
x = get_norm()(x)
return tf.nn.relu(x)
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'), \
argscope(Conv2D, use_bias=False), \
argscope(Conv2D, use_bias=False, activation=nonlin), \
argscope(BatchNorm, training=False), \
custom_getter_scope(maybe_freeze_affine):
yield
......@@ -89,23 +93,23 @@ def image_preprocess(image, bgr=True):
return image
def get_bn(zero_init=False):
if zero_init:
return lambda x, name=None: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
def get_norm(zero_init=False):
if cfg.BACKBONE.NORM == 'GN':
Norm = GroupNorm
layer_name = 'gn'
else:
return lambda x, name=None: BatchNorm('bn', x)
Norm = BatchNorm
layer_name = 'bn'
return lambda x: Norm(layer_name, x, gamma_initializer=tf.zeros_initializer() if zero_init else None)
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3]
n_in = l.shape[1]
if n_in != n_out: # change dimension when channel is not the same
# TF's SAME mode output ceil(x/stride), which is NOT what we want when x is odd and stride is 2
# In FPN mode, the images are pre-padded already.
if not cfg.MODE_FPN and stride == 2:
l = l[:, :, :-1, :-1]
return Conv2D('convshortcut', l, n_out, 1,
strides=stride, padding='VALID', activation=activation)
else:
return Conv2D('convshortcut', l, n_out, 1,
strides=stride, activation=activation)
else:
......@@ -117,17 +121,17 @@ def resnet_bottleneck(l, ch_out, stride):
if cfg.BACKBONE.STRIDE_1X1:
if stride == 2:
l = l[:, :, :-1, :-1]
l = Conv2D('conv1', l, ch_out, 1, strides=stride, activation=BNReLU)
l = Conv2D('conv2', l, ch_out, 3, strides=1, activation=BNReLU)
l = Conv2D('conv1', l, ch_out, 1, strides=stride)
l = Conv2D('conv2', l, ch_out, 3, strides=1)
else:
l = Conv2D('conv1', l, ch_out, 1, strides=1, activation=BNReLU)
l = Conv2D('conv1', l, ch_out, 1, strides=1)
if stride == 2:
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = Conv2D('conv2', l, ch_out, 3, strides=2, activation=BNReLU, padding='VALID')
l = Conv2D('conv2', l, ch_out, 3, strides=2, padding='VALID')
else:
l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))
ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
l = Conv2D('conv2', l, ch_out, 3, strides=stride)
l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_norm(zero_init=True))
ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_norm(zero_init=False))
return tf.nn.relu(ret, name='output')
......@@ -141,9 +145,9 @@ def resnet_group(name, l, block_func, features, count, stride):
def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
assert len(num_blocks) == 3
with resnet_argscope():
with backbone_argscope():
l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
......@@ -159,7 +163,7 @@ def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
@auto_reuse_variable_scope
def resnet_conv5(image, num_block):
with resnet_argscope(), maybe_syncbn_scope():
with backbone_argscope(), maybe_syncbn_scope():
l = resnet_group('group3', image, resnet_bottleneck, 512, num_block, 2)
return l
......@@ -170,7 +174,7 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
pad_shape2d = new_shape2d - shape2d
assert len(num_blocks) == 4, num_blocks
with resnet_argscope():
with backbone_argscope():
chan = image.shape[1]
pad_base = maybe_reverse_pad(2, 3)
l = tf.pad(image, tf.stack(
......@@ -178,7 +182,7 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
[pad_base[0], pad_base[1] + pad_shape2d[0]],
[pad_base[0], pad_base[1] + pad_shape2d[1]]]))
l.set_shape([None, chan, None, None])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
......
......@@ -132,7 +132,8 @@ _C.FPN.PROPOSAL_MODE = 'Level' # 'Level', 'Joint'
_C.FPN.NUM_CHANNEL = 256
# conv head and fc head are only used in FPN.
# For C4 models, the head is C5
_C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head' # choices: fastrcnn_2fc_head, fastrcnn_4conv1fc_head
_C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head'
# choices: fastrcnn_2fc_head, fastrcnn_4conv1fc_head, fastrcnn_4conv1fc_gn_head
_C.FPN.FRCNN_CONV_HEAD_DIM = 256
_C.FPN.FRCNN_FC_HEAD_DIM = 1024
......@@ -152,7 +153,7 @@ def finalize_configs(is_training):
"""
_C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN'], _C.BACKBONE.NORM
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN':
assert not _C.BACKBONE.FREEZE_AFFINE
......
......@@ -56,7 +56,7 @@ def fpn_model(features):
lat_sum_5432.append(lat)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])]
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first')
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first', padding='VALID')
return p2345 + [p6]
......
......@@ -243,7 +243,7 @@ class ResNetC4Model(DetectionModel):
wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
return total_cost * (1. / cfg.TRAIN.NUM_GPUS)
else:
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
......@@ -379,7 +379,7 @@ class ResNetFPNModel(DetectionModel):
mrcnn_loss, wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
return total_cost * (1. / cfg.TRAIN.NUM_GPUS)
else:
final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
......@@ -595,8 +595,8 @@ if __name__ == '__main__':
)
if is_horovod:
# horovod mode has the best speed for this model
trainer = HorovodTrainer()
trainer = HorovodTrainer(average=False)
else:
# nccl mode has better speed than cpu mode
trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, mode='nccl')
trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, average=False, mode='nccl')
launch_train_with_config(traincfg, trainer)
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment