Commit d2309a1b authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] GN on FPN; FreezeC2

parent ccda3790
...@@ -71,6 +71,7 @@ def backbone_argscope(): ...@@ -71,6 +71,7 @@ def backbone_argscope():
@contextmanager @contextmanager
def maybe_syncbn_scope(): def maybe_syncbn_scope():
if cfg.BACKBONE.NORM == 'SyncBN': if cfg.BACKBONE.NORM == 'SyncBN':
assert cfg.BACKBONE.FREEZE_AT == 2 # TODO add better support
with argscope(BatchNorm, training=None, sync_statistics='nccl'): with argscope(BatchNorm, training=None, sync_statistics='nccl'):
yield yield
else: else:
...@@ -143,7 +144,7 @@ def resnet_group(name, l, block_func, features, count, stride): ...@@ -143,7 +144,7 @@ def resnet_group(name, l, block_func, features, count, stride):
return l return l
def resnet_c4_backbone(image, num_blocks, freeze_c2=True): def resnet_c4_backbone(image, num_blocks):
assert len(num_blocks) == 3 assert len(num_blocks) == 3
with backbone_argscope(): with backbone_argscope():
l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)]) l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)])
...@@ -152,7 +153,7 @@ def resnet_c4_backbone(image, num_blocks, freeze_c2=True): ...@@ -152,7 +153,7 @@ def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID') l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1) c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
# TODO replace var by const to enable optimization # TODO replace var by const to enable optimization
if freeze_c2: if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2) c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope(): with maybe_syncbn_scope():
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2) c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
...@@ -168,7 +169,7 @@ def resnet_conv5(image, num_block): ...@@ -168,7 +169,7 @@ def resnet_conv5(image, num_block):
return l return l
def resnet_fpn_backbone(image, num_blocks, freeze_c2=True): def resnet_fpn_backbone(image, num_blocks):
shape2d = tf.shape(image)[2:] shape2d = tf.shape(image)[2:]
mult = float(cfg.FPN.RESOLUTION_REQUIREMENT) mult = float(cfg.FPN.RESOLUTION_REQUIREMENT)
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult) new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
...@@ -186,7 +187,7 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True): ...@@ -186,7 +187,7 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)]) l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID') l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1) c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
if freeze_c2: if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2) c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope(): with maybe_syncbn_scope():
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2) c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import os import os
from termcolor import colored from termcolor import colored
from tabulate import tabulate from tabulate import tabulate
import tqdm
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.rect import FloatBox from tensorpack.utils.rect import FloatBox
...@@ -96,7 +97,7 @@ class COCODetection(object): ...@@ -96,7 +97,7 @@ class COCODetection(object):
# list of dict, each has keys: height,width,id,file_name # list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids) imgs = self.coco.loadImgs(img_ids)
for img in imgs: for img in tqdm.tqdm(imgs):
self._use_absolute_file_name(img) self._use_absolute_file_name(img)
if add_gt: if add_gt:
self._add_detection_gt(img, add_mask) self._add_detection_gt(img, add_mask)
......
...@@ -43,6 +43,13 @@ class AttrDict(): ...@@ -43,6 +43,13 @@ class AttrDict():
v = eval(v) v = eval(v)
setattr(dic, key, v) setattr(dic, key, v)
# avoid silent bugs
def __eq__(self, _):
raise NotImplementedError()
def __ne__(self, _):
raise NotImplementedError()
config = AttrDict() config = AttrDict()
_C = config # short alias to avoid coding _C = config # short alias to avoid coding
...@@ -65,6 +72,7 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50 ...@@ -65,6 +72,7 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101 # RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
_C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers _C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN _C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 2
# Use a base model with TF-preferred padding mode, # Use a base model with TF-preferred padding mode,
# which may pad more pixels on right/bottom than top/left. # which may pad more pixels on right/bottom than top/left.
...@@ -131,6 +139,7 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch ...@@ -131,6 +139,7 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch
_C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES _C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64) # strides for each FPN level. Must be the same length as ANCHOR_SIZES
_C.FPN.PROPOSAL_MODE = 'Level' # 'Level', 'Joint' _C.FPN.PROPOSAL_MODE = 'Level' # 'Level', 'Joint'
_C.FPN.NUM_CHANNEL = 256 _C.FPN.NUM_CHANNEL = 256
_C.FPN.NORM = 'None' # 'None', 'GN'
# conv head and fc head are only used in FPN. # conv head and fc head are only used in FPN.
# For C4 models, the head is C5 # For C4 models, the head is C5
_C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head' _C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head'
...@@ -159,6 +168,7 @@ def finalize_configs(is_training): ...@@ -159,6 +168,7 @@ def finalize_configs(is_training):
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN': if _C.BACKBONE.NORM != 'FreezeBN':
assert not _C.BACKBONE.FREEZE_AFFINE assert not _C.BACKBONE.FREEZE_AFFINE
assert _C.BACKBONE.FREEZE_AT in [0, 2]
_C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS) _C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS)
assert len(_C.FPN.ANCHOR_STRIDES) == len(_C.RPN.ANCHOR_SIZES) assert len(_C.FPN.ANCHOR_STRIDES) == len(_C.RPN.ANCHOR_SIZES)
...@@ -171,6 +181,7 @@ def finalize_configs(is_training): ...@@ -171,6 +181,7 @@ def finalize_configs(is_training):
assert _C.FPN.PROPOSAL_MODE in ['Level', 'Joint'] assert _C.FPN.PROPOSAL_MODE in ['Level', 'Joint']
assert _C.FPN.FRCNN_HEAD_FUNC.endswith('_head') assert _C.FPN.FRCNN_HEAD_FUNC.endswith('_head')
assert _C.FPN.MRCNN_HEAD_FUNC.endswith('_head') assert _C.FPN.MRCNN_HEAD_FUNC.endswith('_head')
assert _C.FPN.NORM in ['None', 'GN']
if is_training: if is_training:
os.environ['TF_AUTOTUNE_THRESHOLD'] = '1' os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
......
...@@ -137,7 +137,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes): ...@@ -137,7 +137,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
# cand_inds = np.where(anchor_labels >= 0)[0] # cand_inds = np.where(anchor_labels >= 0)[0]
# cand_anchors = anchors[cand_inds] # cand_anchors = anchors[cand_inds]
# ious = np_iou(cand_anchors, crowd_boxes) # ious = np_iou(cand_anchors, crowd_boxes)
# overlap_with_crowd = cand_inds[ious.max(axis=1) > cfg.RPN.CROWD_OVERLAP_THRESH] # overlap_with_crowd = cand_inds[ious.max(axis=1) > cfg.RPN.CROWD_OVERLAP_THRES]
# anchor_labels[overlap_with_crowd] = -1 # anchor_labels[overlap_with_crowd] = -1
# Subsample fg labels: ignore some fg if fg is too many # Subsample fg labels: ignore some fg if fg is too many
......
...@@ -15,6 +15,7 @@ from model_rpn import rpn_losses, generate_rpn_proposals ...@@ -15,6 +15,7 @@ from model_rpn import rpn_losses, generate_rpn_proposals
from model_box import roi_align from model_box import roi_align
from utils.box_ops import area as tf_area from utils.box_ops import area as tf_area
from config import config as cfg from config import config as cfg
from basemodel import GroupNorm
@layer_register(log_shape=True) @layer_register(log_shape=True)
...@@ -29,6 +30,8 @@ def fpn_model(features): ...@@ -29,6 +30,8 @@ def fpn_model(features):
assert len(features) == 4, features assert len(features) == 4, features
num_channel = cfg.FPN.NUM_CHANNEL num_channel = cfg.FPN.NUM_CHANNEL
use_gn = cfg.FPN.NORM == 'GN'
def upsample2x(name, x): def upsample2x(name, x):
return FixedUnPooling( return FixedUnPooling(
name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'), name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
...@@ -47,6 +50,8 @@ def fpn_model(features): ...@@ -47,6 +50,8 @@ def fpn_model(features):
kernel_initializer=tf.variance_scaling_initializer(scale=1.)): kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1) lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
for i, c in enumerate(features)] for i, c in enumerate(features)]
if use_gn:
lat_2345 = [GroupNorm('gn_c{}'.format(i + 2), c) for i, c in enumerate(lat_2345)]
lat_sum_5432 = [] lat_sum_5432 = []
for idx, lat in enumerate(lat_2345[::-1]): for idx, lat in enumerate(lat_2345[::-1]):
if idx == 0: if idx == 0:
...@@ -56,6 +61,8 @@ def fpn_model(features): ...@@ -56,6 +61,8 @@ def fpn_model(features):
lat_sum_5432.append(lat) lat_sum_5432.append(lat)
p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3) p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
for i, c in enumerate(lat_sum_5432[::-1])] for i, c in enumerate(lat_sum_5432[::-1])]
if use_gn:
p2345 = [GroupNorm('gn_p{}'.format(i + 2), c) for i, c in enumerate(p2345)]
p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first', padding='VALID') p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first', padding='VALID')
return p2345 + [p6] return p2345 + [p6]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment