Commit d2309a1b authored by Yuxin Wu

[MaskRCNN] GN on FPN; FreezeC2

parent ccda3790
......@@ -71,6 +71,7 @@ def backbone_argscope():
@contextmanager
def maybe_syncbn_scope():
if cfg.BACKBONE.NORM == 'SyncBN':
assert cfg.BACKBONE.FREEZE_AT == 2 # TODO add better support
with argscope(BatchNorm, training=None, sync_statistics='nccl'):
yield
else:
......@@ -143,7 +144,7 @@ def resnet_group(name, l, block_func, features, count, stride):
return l
def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
def resnet_c4_backbone(image, num_blocks):
assert len(num_blocks) == 3
with backbone_argscope():
l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)])
......@@ -152,7 +153,7 @@ def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
# TODO replace var by const to enable optimization
if freeze_c2:
if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope():
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
......@@ -168,7 +169,7 @@ def resnet_conv5(image, num_block):
return l
def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
def resnet_fpn_backbone(image, num_blocks):
shape2d = tf.shape(image)[2:]
mult = float(cfg.FPN.RESOLUTION_REQUIREMENT)
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
......@@ -186,7 +187,7 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
if freeze_c2:
if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope():
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
......
......@@ -5,6 +5,7 @@ import numpy as np
import os
from termcolor import colored
from tabulate import tabulate
import tqdm
from tensorpack.utils import logger
from tensorpack.utils.rect import FloatBox
......@@ -96,7 +97,7 @@ class COCODetection(object):
# list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids)
for img in imgs:
for img in tqdm.tqdm(imgs):
self._use_absolute_file_name(img)
if add_gt:
self._add_detection_gt(img, add_mask)
......
......@@ -43,6 +43,13 @@ class AttrDict():
v = eval(v)
setattr(dic, key, v)
# avoid silent bugs
# `==` on an AttrDict always raises NotImplementedError, so an equality test
# written against a config node fails loudly instead of yielding a bool.
def __eq__(self, _):
raise NotImplementedError()
# `!=` on an AttrDict always raises NotImplementedError rather than returning
# a bool; the compared-against value is ignored.
def __ne__(self, _):
raise NotImplementedError()
config = AttrDict()
_C = config # short alias to avoid coding
......@@ -65,6 +72,7 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
_C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 2
# Use a base model with TF-preferred padding mode,
# which may pad more pixels on right/bottom than top/left.
......@@ -131,6 +139,7 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch
_C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES
_C.FPN.PROPOSAL_MODE = 'Level' # 'Level', 'Joint'
_C.FPN.NUM_CHANNEL = 256
_C.FPN.NORM = 'None' # 'None', 'GN'
# conv head and fc head are only used in FPN.
# For C4 models, the head is C5
_C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head'
......@@ -159,6 +168,7 @@ def finalize_configs(is_training):
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN':
assert not _C.BACKBONE.FREEZE_AFFINE
assert _C.BACKBONE.FREEZE_AT in [0, 2]
_C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS)
assert len(_C.FPN.ANCHOR_STRIDES) == len(_C.RPN.ANCHOR_SIZES)
......@@ -171,6 +181,7 @@ def finalize_configs(is_training):
assert _C.FPN.PROPOSAL_MODE in ['Level', 'Joint']
assert _C.FPN.FRCNN_HEAD_FUNC.endswith('_head')
assert _C.FPN.MRCNN_HEAD_FUNC.endswith('_head')
assert _C.FPN.NORM in ['None', 'GN']
if is_training:
os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
......
......@@ -137,7 +137,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
# cand_inds = np.where(anchor_labels >= 0)[0]
# cand_anchors = anchors[cand_inds]
# ious = np_iou(cand_anchors, crowd_boxes)
# overlap_with_crowd = cand_inds[ious.max(axis=1) > cfg.RPN.CROWD_OVERLAP_THRESH]
# overlap_with_crowd = cand_inds[ious.max(axis=1) > cfg.RPN.CROWD_OVERLAP_THRES]
# anchor_labels[overlap_with_crowd] = -1
# Subsample fg labels: ignore some fg if fg is too many
......
......@@ -15,6 +15,7 @@ from model_rpn import rpn_losses, generate_rpn_proposals
from model_box import roi_align
from utils.box_ops import area as tf_area
from config import config as cfg
from basemodel import GroupNorm
@layer_register(log_shape=True)
......@@ -29,6 +30,8 @@ def fpn_model(features):
assert len(features) == 4, features
num_channel = cfg.FPN.NUM_CHANNEL
use_gn = cfg.FPN.NORM == 'GN'
def upsample2x(name, x):
return FixedUnPooling(
name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
......@@ -47,6 +50,8 @@ def fpn_model(features):
kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
for i, c in enumerate(features)]
if use_gn:
lat_2345 = [GroupNorm('gn_c{}'.format(i + 2), c) for i, c in enumerate(lat_2345)]
lat_sum_5432 = []
for idx, lat in enumerate(lat_2345[::-1]):
if idx == 0:
......@@ -56,6 +61,8 @@ def fpn_model(features):
lat_sum_5432.append(lat)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])]
if use_gn:
p2345 = [GroupNorm('gn_p{}'.format(i + 2), c) for i, c in enumerate(p2345)]
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first', padding='VALID')
return p2345 + [p6]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment