Commit 9b1d1095 authored by Yuxin Wu

[MaskRCNN] use attrdict for config

parent 4f1efe74
basemodel.py
@@ -9,20 +9,20 @@ from tensorpack.tfutils.varreplace import custom_getter_scope
 from tensorpack.models import (
     Conv2D, MaxPooling, BatchNorm, BNReLU)
-import config
+from config import config as cfg


 def maybe_freeze_affine(getter, *args, **kwargs):
     # custom getter to freeze affine params inside bn
     name = args[0] if len(args) else kwargs.get('name')
     if name.endswith('/gamma') or name.endswith('/beta'):
-        if config.FREEZE_AFFINE:
+        if cfg.BACKBONE.FREEZE_AFFINE:
             kwargs['trainable'] = False
     return getter(*args, **kwargs)


 def maybe_reverse_pad(topleft, bottomright):
-    if config.TF_PAD_MODE:
+    if cfg.BACKBONE.TF_PAD_MODE:
         return [topleft, bottomright]
     return [bottomright, topleft]
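Aside on the hunk above: `maybe_freeze_affine` is a variable custom getter, meant to be installed with the `custom_getter_scope` imported at the top of this file so that every `tf.get_variable` call inside the scope passes through it. A minimal sketch of the intended wiring; the `with`-block below is illustrative, not part of this commit:

    from tensorpack.tfutils.varreplace import custom_getter_scope

    # Any BatchNorm gamma/beta created inside this scope is returned with
    # trainable=False when cfg.BACKBONE.FREEZE_AFFINE is set.
    with custom_getter_scope(maybe_freeze_affine):
        featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])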
@@ -65,7 +65,7 @@ def resnet_shortcut(l, n_out, stride, activation=tf.identity):
     n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3]
     if n_in != n_out:   # change dimension when channel is not the same
         # TF's SAME mode output ceil(x/stride), which is NOT what we want when x is odd and stride is 2
-        if not config.MODE_FPN and stride == 2:
+        if not cfg.MODE_FPN and stride == 2:
             l = l[:, :, :-1, :-1]
         return Conv2D('convshortcut', l, n_out, 1,
                       strides=stride, padding='VALID', activation=activation)
@@ -124,7 +124,7 @@ def resnet_conv5(image, num_block):
 def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
     shape2d = tf.shape(image)[2:]
-    mult = float(config.FPN_RESOLUTION_REQUIREMENT)
+    mult = float(cfg.FPN.RESOLUTION_REQUIREMENT)
     new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
     pad_shape2d = new_shape2d - shape2d
     assert len(num_blocks) == 4, num_blocks
...
coco.py
@@ -12,13 +12,13 @@ from tensorpack.utils.timer import timed_operation
 from tensorpack.utils.argtools import log_once
 from pycocotools.coco import COCO

-import config
+from config import config as cfg

 __all__ = ['COCODetection', 'COCOMeta']

 COCO_NUM_CATEGORY = 80
-config.NUM_CLASS = COCO_NUM_CATEGORY + 1
+cfg.DATA.NUM_CLASS = COCO_NUM_CATEGORY + 1


 class _COCOMeta(object):
@@ -48,7 +48,7 @@ class _COCOMeta(object):
             v: i + 1 for i, v in enumerate(cat_ids)}
         self.class_id_to_category_id = {
             v: k for k, v in self.category_id_to_class_id.items()}
-        config.CLASS_NAMES = self.class_names
+        cfg.DATA.CLASS_NAMES = self.class_names

 COCOMeta = _COCOMeta()
@@ -200,7 +200,7 @@ class COCODetection(object):
 if __name__ == '__main__':
-    c = COCODetection(config.BASEDIR, 'train2014')
+    c = COCODetection(cfg.DATA.BASEDIR, 'train2014')
     gt_boxes = c.load(add_gt=True, add_mask=True)
     print("#Images:", len(gt_boxes))
     c.print_class_histogram(gt_boxes)
common.py
@@ -2,17 +2,13 @@
 # File: common.py

 import numpy as np
-import six
 import cv2
 from tensorpack.dataflow import RNGDataFlow
 from tensorpack.dataflow.imgaug import transform
-from tensorpack.utils import logger
 import pycocotools.mask as cocomask
-import config


 class DataFromListOfDict(RNGDataFlow):
     def __init__(self, lst, keys, shuffle=False):
@@ -138,21 +134,3 @@ def filter_boxes_inside_shape(boxes, shape):
         (boxes[:, 2] <= w) &
         (boxes[:, 3] <= h))[0]
     return indices, boxes[indices, :]
-
-
-def write_config_from_args(configs):
-    for cfg in configs:
-        k, v = cfg.split('=', maxsplit=1)
-        assert k in dir(config), "Unknown config key: {}".format(k)
-        oldv = getattr(config, k)
-        if not isinstance(oldv, six.text_type):
-            v = eval(v)
-        setattr(config, k, v)
-
-
-def print_config():
-    logger.info("Config: ------------------------------------------")
-    for k in dir(config):
-        if k == k.upper():
-            logger.info("{} = {}".format(k, getattr(config, k)))
-    logger.info("--------------------------------------------------")
config.py
 # -*- coding: utf-8 -*-
 # File: config.py

-import numpy as np
+import pprint
+
+__all__ = ['config']
+
+
+class AttrDict():
+    def __getattr__(self, name):
+        ret = AttrDict()
+        setattr(self, name, ret)
+        return ret
+
+    def __str__(self):
+        return pprint.pformat(self.to_dict(), indent=1)
+
+    __repr__ = __str__
+
+    def to_dict(self):
+        """Convert to a nested dict. """
+        return {k: v.to_dict() if isinstance(v, AttrDict) else v
+                for k, v in self.__dict__.items()}
+
+    def update_args(self, args):
+        """Update from command line args. """
+        for cfg in args:
+            keys, v = cfg.split('=', maxsplit=1)
+            keylist = keys.split('.')
+            dic = self
+            for i, k in enumerate(keylist[:-1]):
+                assert k in dir(dic), "Unknown config key: {}".format(keys)
+                dic = getattr(dic, k)
+            key = keylist[-1]
+            oldv = getattr(dic, key)
+            if not isinstance(oldv, str):
+                v = eval(v)
+            setattr(dic, key, v)
+
+
+config = AttrDict()
+_C = config     # short alias to avoid coding
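A quick sketch of the `AttrDict` semantics just defined: reading a missing attribute creates and attaches an empty child `AttrDict`, which is why the assignments below can write `_C.DATA.BASEDIR = ...` without declaring the `DATA` namespace first. Illustrative only:

    c = AttrDict()
    c.DATA.BASEDIR = '/data/coco'                 # 'DATA' is auto-created on first access
    assert c.to_dict() == {'DATA': {'BASEDIR': '/data/coco'}}
    c.update_args(['DATA.BASEDIR=/other/path'])   # old value is a str, so kept verbatim
    c.update_args(['DTA.BASEDIR=/x'])             # unknown namespace -> AssertionError
    # note: only intermediate namespaces are validated; a typo in the leaf key
    # would silently create a new entry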
 # mode flags ---------------------
-TRAINER = 'replicated'  # options: 'horovod', 'replicated'
-NUM_GPUS = None  # by default, will be set from code
-MODE_MASK = True
-MODE_FPN = False
+_C.TRAINER = 'replicated'  # options: 'horovod', 'replicated'
+_C.MODE_MASK = True
+_C.MODE_FPN = False

 # dataset -----------------------
-BASEDIR = '/path/to/your/COCO/DIR'
-TRAIN_DATASET = ['train2014', 'valminusminival2014']   # i.e., trainval35k
-VAL_DATASET = 'minival2014'   # For now, only support evaluation on single dataset
-NUM_CLASS = 81   # 1 background + 80 categories
-CLASS_NAMES = []  # NUM_CLASS strings. Needs to be populated later by data loader
+_C.DATA.BASEDIR = '/path/to/your/COCO/DIR'
+_C.DATA.TRAIN = ['train2014', 'valminusminival2014']   # i.e., trainval35k
+_C.DATA.VAL = 'minival2014'   # For now, only support evaluation on single dataset
+_C.DATA.NUM_CLASS = 81    # 1 background + 80 categories
+_C.DATA.CLASS_NAMES = []  # NUM_CLASS strings. Needs to be populated later by data loader

 # basemodel ----------------------
-RESNET_NUM_BLOCK = [3, 4, 6, 3]     # for resnet50
+_C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3]     # for resnet50
 # RESNET_NUM_BLOCK = [3, 4, 23, 3]    # for resnet101
-FREEZE_AFFINE = False   # do not train affine parameters inside BN
+_C.BACKBONE.FREEZE_AFFINE = False   # do not train affine parameters inside BN
 # Use a base model with TF-preferred pad mode,
 # which may pad more pixels on right/bottom than top/left.
 # TF_PAD_MODE=False is better for performance but will require a different base model.
 # See https://github.com/tensorflow/tensorflow/issues/18213
-TF_PAD_MODE = True
+_C.BACKBONE.TF_PAD_MODE = True

 # schedule -----------------------
-BASE_LR = 1e-2
-WARMUP = 1000  # in steps
-STEPS_PER_EPOCH = 500
+# The schedule and learning rate here is defined for a total batch size of 8.
+# If not running with 8 GPUs, they will be adjusted automatically in code.
+_C.TRAIN.NUM_GPUS = None  # by default, will be set from code
+_C.TRAIN.WEIGHT_DECAY = 1e-4
+_C.TRAIN.BASE_LR = 1e-2
+_C.TRAIN.WARMUP = 1000  # in steps
+_C.TRAIN.STEPS_PER_EPOCH = 500
 # LR_SCHEDULE = [120000, 160000, 180000]      # "1x" schedule in detectron
 # LR_SCHEDULE = [150000, 230000, 280000]      # roughly a "1.5x" schedule
-LR_SCHEDULE = [240000, 320000, 360000]      # "2x" schedule in detectron
+_C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000]    # "2x" schedule in detectron

-# image resolution --------------------
-SHORT_EDGE_SIZE = 800
-MAX_SIZE = 1333
-# Alternative (worse & faster) setting: 600, 1024
+# preprocessing --------------------
+_C.PREPROC.SHORT_EDGE_SIZE = 800
+_C.PREPROC.MAX_SIZE = 1333
+# Alternative old (worse & faster) setting: 600, 1024

 # anchors -------------------------
-ANCHOR_STRIDE = 16
-ANCHOR_STRIDES_FPN = (4, 8, 16, 32, 64)   # strides for each FPN level. Must be the same length as ANCHOR_SIZES
-FPN_RESOLUTION_REQUIREMENT = 32   # image size into the backbone has to be multiple of this number
-ANCHOR_SIZES = (32, 64, 128, 256, 512)    # sqrtarea of the anchor box
-ANCHOR_RATIOS = (0.5, 1., 2.)
-NUM_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
-POSITIVE_ANCHOR_THRES = 0.7
-NEGATIVE_ANCHOR_THRES = 0.3
-BBOX_DECODE_CLIP = np.log(MAX_SIZE / 16.0)  # to avoid too large numbers.
+_C.RPN.ANCHOR_STRIDE = 16
+_C.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512)   # sqrtarea of the anchor box
+_C.RPN.ANCHOR_RATIOS = (0.5, 1., 2.)
+_C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS)
+_C.RPN.POSITIVE_ANCHOR_THRES = 0.7
+_C.RPN.NEGATIVE_ANCHOR_THRES = 0.3

 # rpn training -------------------------
-RPN_FG_RATIO = 0.5  # fg ratio among selected RPN anchors
-RPN_BATCH_PER_IM = 256  # total (across FPN levels) number of anchors that are marked valid
-RPN_MIN_SIZE = 0
-RPN_PROPOSAL_NMS_THRESH = 0.7
-TRAIN_PRE_NMS_TOPK = 12000
-TRAIN_POST_NMS_TOPK = 2000
-TRAIN_FPN_NMS_TOPK = 2000
-CROWD_OVERLAP_THRES = 0.7  # boxes overlapping crowd will be ignored.
+_C.RPN.FG_RATIO = 0.5  # fg ratio among selected RPN anchors
+_C.RPN.BATCH_PER_IM = 256  # total (across FPN levels) number of anchors that are marked valid
+_C.RPN.MIN_SIZE = 0
+_C.RPN.PROPOSAL_NMS_THRESH = 0.7
+_C.RPN.TRAIN_PRE_NMS_TOPK = 12000
+_C.RPN.TRAIN_POST_NMS_TOPK = 2000
+_C.RPN.CROWD_OVERLAP_THRES = 0.7  # boxes overlapping crowd will be ignored.

 # fastrcnn training ---------------------
-FASTRCNN_BATCH_PER_IM = 512
-FASTRCNN_BBOX_REG_WEIGHTS = [10., 10., 5., 5.]  # Better but non-standard setting: [20, 20, 10, 10]
-FASTRCNN_FG_THRESH = 0.5
-FASTRCNN_FG_RATIO = 0.25  # fg ratio in a ROI batch
+_C.FRCNN.BATCH_PER_IM = 512
+_C.FRCNN.BBOX_REG_WEIGHTS = [10., 10., 5., 5.]  # Better but non-standard setting: [20, 20, 10, 10]
+_C.FRCNN.FG_THRESH = 0.5
+_C.FRCNN.FG_RATIO = 0.25  # fg ratio in a ROI batch

-# modeling -------------------------
-FPN_NUM_CHANNEL = 256
+# FPN -------------------------
+_C.FPN.ANCHOR_STRIDES = (4, 8, 16, 32, 64)  # strides for each FPN level. Must be the same length as ANCHOR_SIZES
+_C.FPN.RESOLUTION_REQUIREMENT = 32  # image size into the backbone has to be multiple of this number
+_C.FPN.NUM_CHANNEL = 256
 # conv head and fc head are only used in FPN.
 # For C4 models, the head is C5
-FPN_FASTRCNN_HEAD_FUNC = 'fastrcnn_2fc_head'  # choices: fastrcnn_2fc_head, fastrcnn_4conv1fc_head
-FASTRCNN_CONV_HEAD_DIM = 256
-FASTRCNN_FC_HEAD_DIM = 1024
-MASKRCNN_HEAD_DIM = 256
+_C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head'  # choices: fastrcnn_2fc_head, fastrcnn_4conv1fc_head
+_C.FPN.FRCNN_CONV_HEAD_DIM = 256
+_C.FPN.FRCNN_FC_HEAD_DIM = 1024
+_C.RPN.TRAIN_FPN_NMS_TOPK = 2000
+_C.RPN.TEST_FPN_NMS_TOPK = 1000
+
+# Mask-RCNN
+_C.MRCNN.HEAD_DIM = 256

 # testing -----------------------
-TEST_PRE_NMS_TOPK = 6000
-TEST_POST_NMS_TOPK = 1000   # if you encounter OOM in inference, set this to a smaller number
-TEST_FPN_NMS_TOPK = 1000
-FASTRCNN_NMS_THRESH = 0.5
-RESULT_SCORE_THRESH = 0.05
-RESULT_SCORE_THRESH_VIS = 0.3   # only visualize confident results
-RESULTS_PER_IM = 100
+_C.RPN.TEST_PRE_NMS_TOPK = 6000
+_C.RPN.TEST_POST_NMS_TOPK = 1000   # if you encounter OOM in inference, set this to a smaller number
+_C.TEST.FRCNN_NMS_THRESH = 0.5
+_C.TEST.RESULT_SCORE_THRESH = 0.05
+_C.TEST.RESULT_SCORE_THRESH_VIS = 0.3   # only visualize confident results
+_C.TEST.RESULTS_PER_IM = 100
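Note that `_C.RPN.NUM_ANCHOR` above is evaluated once at import time; overriding `RPN.ANCHOR_SIZES` or `RPN.ANCHOR_RATIOS` later through `update_args` will not recompute it. A small sketch of that caveat:

    from config import config as cfg
    assert cfg.RPN.NUM_ANCHOR == 15                 # 5 sizes x 3 ratios
    cfg.update_args(['RPN.ANCHOR_RATIOS=(1.,)'])
    # NUM_ANCHOR is still 15; re-derive it by hand after such an override:
    cfg.RPN.NUM_ANCHOR = len(cfg.RPN.ANCHOR_SIZES) * len(cfg.RPN.ANCHOR_RATIOS)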
data.py
@@ -21,7 +21,7 @@ from utils.np_box_ops import area as np_area
 from common import (
     DataFromListOfDict, CustomResize, filter_boxes_inside_shape,
     box_to_point8, point8_to_box, segmentation_to_mask)
-import config
+from config import config as cfg


 class MalformedData(BaseException):
@@ -30,8 +30,8 @@ class MalformedData(BaseException):
 @memoized
 def get_all_anchors(
-        stride=config.ANCHOR_STRIDE,
-        sizes=config.ANCHOR_SIZES):
+        stride=cfg.RPN.ANCHOR_STRIDE,
+        sizes=cfg.RPN.ANCHOR_SIZES):
     """
     Get all anchors in the largest possible image, shifted, floatbox
     Args:
@@ -49,14 +49,14 @@ def get_all_anchors(
     cell_anchors = generate_anchors(
         stride,
         scales=np.array(sizes, dtype=np.float) / stride,
-        ratios=np.array(config.ANCHOR_RATIOS, dtype=np.float))
+        ratios=np.array(cfg.RPN.ANCHOR_RATIOS, dtype=np.float))
     # anchors are intbox here.
     # anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
-    max_size = config.MAX_SIZE
-    if config.MODE_FPN:
+    max_size = cfg.PREPROC.MAX_SIZE
+    if cfg.MODE_FPN:
         # TODO setting this in config is perhaps better
-        size_mult = config.FPN_RESOLUTION_REQUIREMENT * 1.
+        size_mult = cfg.FPN.RESOLUTION_REQUIREMENT * 1.
         max_size = np.ceil(max_size / size_mult) * size_mult
     field_size = int(np.ceil(max_size / stride))
     shifts = np.arange(0, field_size) * stride
@@ -81,8 +81,8 @@ def get_all_anchors(
 @memoized
 def get_all_anchors_fpn(
-        strides=config.ANCHOR_STRIDES_FPN,
-        sizes=config.ANCHOR_SIZES):
+        strides=cfg.FPN.ANCHOR_STRIDES,
+        sizes=cfg.RPN.ANCHOR_SIZES):
     """
     Returns:
         [anchors]: each anchors is a SxSx NUM_ANCHOR_RATIOS x4 array.
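To make the shapes concrete, a back-of-the-envelope sketch of the single-level anchor grid that `get_all_anchors` produces with the defaults (no FPN):

    import numpy as np
    stride, max_size = 16, 1333                    # cfg.RPN.ANCHOR_STRIDE, cfg.PREPROC.MAX_SIZE
    field_size = int(np.ceil(max_size / stride))   # 84
    num_anchor = 5 * 3                             # len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
    # get_all_anchors() then covers the largest possible image with a
    # (field_size, field_size, num_anchor, 4) = (84, 84, 15, 4) float32 array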
@@ -132,8 +132,8 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
     # the order of setting neg/pos labels matter
     anchor_labels[anchors_with_max_iou_per_gt] = 1
-    anchor_labels[ious_max_per_anchor >= config.POSITIVE_ANCHOR_THRES] = 1
-    anchor_labels[ious_max_per_anchor < config.NEGATIVE_ANCHOR_THRES] = 0
+    anchor_labels[ious_max_per_anchor >= cfg.RPN.POSITIVE_ANCHOR_THRES] = 1
+    anchor_labels[ious_max_per_anchor < cfg.RPN.NEGATIVE_ANCHOR_THRES] = 0

     # We can label all non-ignore candidate boxes which overlap crowd as ignore
     # But detectron did not do this.
@@ -141,11 +141,11 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
     # cand_inds = np.where(anchor_labels >= 0)[0]
     # cand_anchors = anchors[cand_inds]
     # ious = np_iou(cand_anchors, crowd_boxes)
-    # overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES]
+    # overlap_with_crowd = cand_inds[ious.max(axis=1) > cfg.RPN.CROWD_OVERLAP_THRES]
     # anchor_labels[overlap_with_crowd] = -1

     # Subsample fg labels: ignore some fg if fg is too many
-    target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO)
+    target_num_fg = int(cfg.RPN.BATCH_PER_IM * cfg.RPN.FG_RATIO)
     fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
     # Keep an image even if there is no foreground anchors
     # if len(fg_inds) == 0:
@@ -156,14 +156,14 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
     if old_num_bg == 0:
         # No valid bg in this image, skip.
         raise MalformedData("No valid background for RPN!")
-    target_num_bg = config.RPN_BATCH_PER_IM - len(fg_inds)
+    target_num_bg = cfg.RPN.BATCH_PER_IM - len(fg_inds)
     filter_box_label(anchor_labels, 0, target_num_bg)   # ignore return values

     # Set anchor boxes: the best gt_box for each fg anchor
     anchor_boxes = np.zeros((NA, 4), dtype='float32')
     fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
     anchor_boxes[fg_inds, :] = fg_boxes
-    # assert len(fg_inds) + np.sum(anchor_labels == 0) == config.RPN_BATCH_PER_IM
+    # assert len(fg_inds) + np.sum(anchor_labels == 0) == cfg.RPN.BATCH_PER_IM
     return anchor_labels, anchor_boxes
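Plugging the defaults into the hunk above makes the anchor budget explicit (worked numbers only, no new behavior):

    # cfg.RPN.BATCH_PER_IM = 256, cfg.RPN.FG_RATIO = 0.5
    target_num_fg = int(256 * 0.5)   # at most 128 anchors stay labeled foreground
    # background then fills the remainder: target_num_bg = 256 - len(fg_inds),
    # so at most 256 anchors per image contribute to the RPN losses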
@@ -192,12 +192,12 @@ def get_rpn_anchor_input(im, boxes, is_crowd):
     # Fill them back to original size: fHxfWx1, fHxfWx4
     anchorH, anchorW = all_anchors.shape[:2]
-    featuremap_labels = -np.ones((anchorH * anchorW * config.NUM_ANCHOR, ), dtype='int32')
+    featuremap_labels = -np.ones((anchorH * anchorW * cfg.RPN.NUM_ANCHOR, ), dtype='int32')
     featuremap_labels[inside_ind] = anchor_labels
-    featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, config.NUM_ANCHOR))
-    featuremap_boxes = np.zeros((anchorH * anchorW * config.NUM_ANCHOR, 4), dtype='float32')
+    featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, cfg.RPN.NUM_ANCHOR))
+    featuremap_boxes = np.zeros((anchorH * anchorW * cfg.RPN.NUM_ANCHOR, 4), dtype='float32')
     featuremap_boxes[inside_ind, :] = anchor_gt_boxes
-    featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, config.NUM_ANCHOR, 4))
+    featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, cfg.RPN.NUM_ANCHOR, 4))
     return featuremap_labels, featuremap_boxes
@@ -233,7 +233,7 @@ def get_multilevel_rpn_anchor_input(im, boxes, is_crowd):
     start = 0
     multilevel_inputs = []
     for level_anchor in anchors_per_level:
-        assert level_anchor.shape[2] == len(config.ANCHOR_RATIOS)
+        assert level_anchor.shape[2] == len(cfg.RPN.ANCHOR_RATIOS)
         anchor_shape = level_anchor.shape[:3]   # fHxfWxNUM_ANCHOR_RATIOS
         num_anchor_this_level = np.prod(anchor_shape)
         end = start + num_anchor_this_level
@@ -263,7 +263,7 @@ def get_train_dataflow():
     """
     imgs = COCODetection.load_many(
-        config.BASEDIR, config.TRAIN_DATASET, add_gt=True, add_mask=config.MODE_MASK)
+        cfg.DATA.BASEDIR, cfg.DATA.TRAIN, add_gt=True, add_mask=cfg.MODE_MASK)
     """
     To train on your own data, change this to your loader.
     Produce "imgs" as a list of dict, in the dict the following keys are needed for training:
@@ -292,7 +292,7 @@ def get_train_dataflow():
     ds = DataFromList(imgs, shuffle=True)

     aug = imgaug.AugmentorList(
-        [CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
+        [CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
         imgaug.Flip(horiz=True)])

     def preprocess(img):
@@ -313,7 +313,7 @@ def get_train_dataflow():
         # rpn anchor:
         try:
-            if config.MODE_FPN:
+            if cfg.MODE_FPN:
                 multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
                 anchor_inputs = itertools.chain.from_iterable(multilevel_anchor_inputs)
             else:
@@ -331,7 +331,7 @@ def get_train_dataflow():
         ret = [im] + list(anchor_inputs) + [boxes, klass]

-        if config.MODE_MASK:
+        if cfg.MODE_MASK:
             # augmentation will modify the polys in-place
             segmentation = copy.deepcopy(img['segmentation'])
             segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
@@ -353,7 +353,7 @@ def get_train_dataflow():
         #     tpviz.interactive_imshow(viz)
         return ret

-    if config.TRAINER == 'horovod':
+    if cfg.TRAINER == 'horovod':
         ds = MultiThreadMapData(ds, 5, preprocess)
         # MPI does not like fork()
     else:
@@ -362,7 +362,7 @@ def get_train_dataflow():
 def get_eval_dataflow():
-    imgs = COCODetection.load_many(config.BASEDIR, config.VAL_DATASET, add_gt=False)
+    imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
     # no filter for training
     ds = DataFromListOfDict(imgs, ['file_name', 'id'])
@@ -371,7 +371,7 @@ def get_eval_dataflow():
         assert im is not None, fname
         return im
     ds = MapDataComponent(ds, f, 0)
-    if config.TRAINER != 'horovod':
+    if cfg.TRAINER != 'horovod':
         ds = PrefetchDataZMQ(ds, 1)
     return ds
@@ -379,7 +379,7 @@ def get_eval_dataflow():
 if __name__ == '__main__':
     import os
     from tensorpack.dataflow import PrintData
-    config.BASEDIR = os.path.expanduser('~/data/coco')
+    cfg.DATA.BASEDIR = os.path.expanduser('~/data/coco')
     ds = get_train_dataflow()
     ds = PrintData(ds, 100)
     TestDataSpeed(ds, 50000).start()
...
eval.py
@@ -15,7 +15,7 @@ import pycocotools.mask as cocomask
 from coco import COCOMeta
 from common import CustomResize, clip_boxes
-import config
+from config import config as cfg

 DetectionResult = namedtuple(
     'DetectionResult',
@@ -69,7 +69,7 @@ def detect_one_image(img, model_func):
     """
     orig_shape = img.shape[:2]
-    resizer = CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE)
+    resizer = CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
     resized_img = resizer.augment(img)
     scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
     boxes, probs, labels, *masks = model_func(resized_img)
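A numeric example of the resize step above, for an assumed 480x640 input with the default `PREPROC` settings:

    h, w = 480, 640                        # assumed input size
    scale = 800. / min(h, w)               # CustomResize maps the short edge to 800
    new_h, new_w = 800, round(w * scale)   # -> 800 x 1067, still under MAX_SIZE=1333
    # `scale` as computed in the code averages the two per-axis ratios (they can
    # differ slightly due to rounding); detections are mapped back to the
    # original image by dividing boxes by this scale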
@@ -131,10 +131,10 @@ def eval_coco(df, detect_func):
 # https://github.com/pdollar/coco/blob/master/PythonAPI/pycocoEvalDemo.ipynb
 def print_evaluation_scores(json_file):
     ret = {}
-    assert config.BASEDIR and os.path.isdir(config.BASEDIR)
+    assert cfg.DATA.BASEDIR and os.path.isdir(cfg.DATA.BASEDIR)
     annofile = os.path.join(
-        config.BASEDIR, 'annotations',
-        'instances_{}.json'.format(config.VAL_DATASET))
+        cfg.DATA.BASEDIR, 'annotations',
+        'instances_{}.json'.format(cfg.DATA.VAL))
     coco = COCO(annofile)
     cocoDt = coco.loadRes(json_file)
     cocoEval = COCOeval(coco, cocoDt, 'bbox')
@@ -145,7 +145,7 @@ def print_evaluation_scores(json_file):
     for k in range(6):
         ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]
-    if config.MODE_MASK:
+    if cfg.MODE_MASK:
         cocoEval = COCOeval(coco, cocoDt, 'segm')
         cocoEval.evaluate()
         cocoEval.accumulate()
...
model.py
@@ -15,7 +15,7 @@ from tensorpack.models import (
 from utils.box_ops import pairwise_iou
 from utils.box_ops import area as tf_area
 from model_box import roi_align, clip_boxes
-import config
+from config import config as cfg


 @layer_register(log_shape=True)
@@ -91,7 +91,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
         placeholder = 0.
         label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
             labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
-        label_loss = tf.reduce_sum(label_loss) * (1. / config.RPN_BATCH_PER_IM)
+        label_loss = tf.reduce_sum(label_loss) * (1. / cfg.RPN.BATCH_PER_IM)
         label_loss = tf.where(tf.equal(nr_valid, 0), placeholder, label_loss, name='label_loss')

     pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask)
@@ -100,7 +100,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
         box_loss = tf.losses.huber_loss(
             pos_anchor_boxes, pos_box_logits, delta=delta,
             reduction=tf.losses.Reduction.SUM) / delta
-        box_loss = box_loss * (1. / config.RPN_BATCH_PER_IM)
+        box_loss = box_loss * (1. / cfg.RPN.BATCH_PER_IM)
         box_loss = tf.where(tf.equal(nr_pos, 0), placeholder, box_loss, name='box_loss')

     add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
@@ -139,7 +139,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
     topk_boxes_x1y1, topk_boxes_x2y2 = tf.split(topk_boxes_x1y1x2y2, 2, axis=1)
     # nx1x2 each
     wbhb = tf.squeeze(topk_boxes_x2y2 - topk_boxes_x1y1, axis=1)
-    valid = tf.reduce_all(wbhb > config.RPN_MIN_SIZE, axis=1)  # n,
+    valid = tf.reduce_all(wbhb > cfg.RPN.MIN_SIZE, axis=1)  # n,
     topk_valid_boxes_x1y1x2y2 = tf.boolean_mask(topk_boxes_x1y1x2y2, valid)
     topk_valid_scores = tf.boolean_mask(topk_scores, valid)
@@ -152,7 +152,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
         # TODO use exp to work around a bug in TF1.9: https://github.com/tensorflow/tensorflow/issues/19578
         tf.exp(topk_valid_scores),
         max_output_size=post_nms_topk,
-        iou_threshold=config.RPN_PROPOSAL_NMS_THRESH)
+        iou_threshold=cfg.RPN.PROPOSAL_NMS_THRESH)

     topk_valid_boxes = tf.reshape(topk_valid_boxes_x1y1x2y2, (-1, 4))
     final_boxes = tf.gather(topk_valid_boxes, nms_indices)
@@ -209,17 +209,17 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
     # #proposal=n+m from now on

     def sample_fg_bg(iou):
-        fg_mask = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH
+        fg_mask = tf.reduce_max(iou, axis=1) >= cfg.FRCNN.FG_THRESH
         fg_inds = tf.reshape(tf.where(fg_mask), [-1])
         num_fg = tf.minimum(int(
-            config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO),
+            cfg.FRCNN.BATCH_PER_IM * cfg.FRCNN.FG_RATIO),
             tf.size(fg_inds), name='num_fg')
         fg_inds = tf.random_shuffle(fg_inds)[:num_fg]

         bg_inds = tf.reshape(tf.where(tf.logical_not(fg_mask)), [-1])
         num_bg = tf.minimum(
-            config.FASTRCNN_BATCH_PER_IM - num_fg,
+            cfg.FRCNN.BATCH_PER_IM - num_fg,
             tf.size(bg_inds), name='num_bg')
         bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
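With the defaults, the renamed keys in this hunk work out to the standard Fast R-CNN sampling scheme (worked numbers only):

    # cfg.FRCNN.BATCH_PER_IM = 512, cfg.FRCNN.FG_RATIO = 0.25, FG_THRESH = 0.5
    num_fg_cap = int(512 * 0.25)   # at most 128 foreground ROIs (best IoU >= 0.5)
    # num_bg = 512 - num_fg, so each image feeds at most 512 ROIs to the box head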
@@ -274,7 +274,7 @@ def fastrcnn_2fc_head(feature, num_classes):
     Returns:
         cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x4)
     """
-    dim = config.FASTRCNN_FC_HEAD_DIM
+    dim = cfg.FPN.FRCNN_FC_HEAD_DIM
     init = tf.variance_scaling_initializer()
     hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, activation=tf.nn.relu)
     hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, activation=tf.nn.relu)
@@ -297,8 +297,8 @@ def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs):
             kernel_initializer=tf.variance_scaling_initializer(
                 scale=2.0, mode='fan_out', distribution='normal')):
         for k in range(num_convs):
-            l = Conv2D('conv{}'.format(k), l, config.FASTRCNN_CONV_HEAD_DIM, 3, activation=tf.nn.relu)
-        l = FullyConnected('fc', l, config.FASTRCNN_FC_HEAD_DIM,
+            l = Conv2D('conv{}'.format(k), l, cfg.FPN.FRCNN_CONV_HEAD_DIM, 3, activation=tf.nn.relu)
+        l = FullyConnected('fc', l, cfg.FPN.FRCNN_FC_HEAD_DIM,
                            kernel_initializer=tf.variance_scaling_initializer(), activation=tf.nn.relu)
     return fastrcnn_outputs('outputs', l, num_classes)
@@ -356,8 +356,8 @@ def fastrcnn_predictions(boxes, probs):
         boxes: n#catx4 floatbox in float32
         probs: nx#class
     """
-    assert boxes.shape[1] == config.NUM_CLASS - 1
-    assert probs.shape[1] == config.NUM_CLASS
+    assert boxes.shape[1] == cfg.DATA.NUM_CLASS - 1
+    assert probs.shape[1] == cfg.DATA.NUM_CLASS
     boxes = tf.transpose(boxes, [1, 0, 2])  # #catxnx4
     probs = tf.transpose(probs[:, 1:], [1, 0])  # #catxn
@@ -371,12 +371,12 @@ def fastrcnn_predictions(boxes, probs):
         prob, box = X
         output_shape = tf.shape(prob)
         # filter by score threshold
-        ids = tf.reshape(tf.where(prob > config.RESULT_SCORE_THRESH), [-1])
+        ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1])
         prob = tf.gather(prob, ids)
         box = tf.gather(box, ids)
         # NMS within each class
         selection = tf.image.non_max_suppression(
-            box, prob, config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
+            box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH)
         selection = tf.to_int32(tf.gather(ids, selection))
         # sort available in TF>1.4.0
         # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
@@ -396,7 +396,7 @@ def fastrcnn_predictions(boxes, probs):
     # filter again by sorting scores
     topk_probs, topk_indices = tf.nn.top_k(
         probs,
-        tf.minimum(config.RESULTS_PER_IM, tf.size(probs)),
+        tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(probs)),
         sorted=False)
     filtered_selection = tf.gather(selected_indices, topk_indices)
     filtered_selection = tf.reverse(filtered_selection, axis=[1], name='filtered_indices')
@@ -420,8 +420,8 @@ def maskrcnn_upXconv_head(feature, num_class, num_convs):
                 scale=2.0, mode='fan_out', distribution='normal')):
         # c2's MSRAFill is fan_out
         for k in range(num_convs):
-            l = Conv2D('fcn{}'.format(k), l, config.MASKRCNN_HEAD_DIM, 3, activation=tf.nn.relu)
-        l = Conv2DTranspose('deconv', l, config.MASKRCNN_HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
+            l = Conv2D('fcn{}'.format(k), l, cfg.MRCNN.HEAD_DIM, 3, activation=tf.nn.relu)
+        l = Conv2DTranspose('deconv', l, cfg.MRCNN.HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
         l = Conv2D('conv', l, num_class - 1, 1)
     return l
@@ -475,7 +475,7 @@ def fpn_model(features):
         [tf.Tensor]: FPN features p2-p6
     """
     assert len(features) == 4, features
-    num_channel = config.FPN_NUM_CHANNEL
+    num_channel = cfg.FPN.NUM_CHANNEL

     def upsample2x(name, x):
         return FixedUnPooling(
@@ -560,7 +560,7 @@ def multilevel_roi_align(features, rcnn_boxes, resolution):
     # Crop patches from corresponding levels
     for i, boxes, featuremap in zip(itertools.count(), level_boxes, features):
         with tf.name_scope('roi_level{}'.format(i + 2)):
-            boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
+            boxes_on_featuremap = boxes * (1.0 / cfg.FPN.ANCHOR_STRIDES[i])
             all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution))

     all_rois = tf.concat(all_rois, axis=0)  # NCHW
...
model_box.py
 # -*- coding: utf-8 -*-
 # File: model_box.py

+import numpy as np
 import tensorflow as tf

 from tensorpack.tfutils.scope_utils import under_name_scope

-import config
+from config import config


 @under_name_scope()
@@ -41,8 +42,8 @@ def decode_bbox_target(box_predictions, anchors):
     waha = anchors_x2y2 - anchors_x1y1
     xaya = (anchors_x2y2 + anchors_x1y1) * 0.5

-    wbhb = tf.exp(tf.minimum(
-        box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
+    clip = np.log(config.PREPROC.MAX_SIZE / 16.)
+    wbhb = tf.exp(tf.minimum(box_pred_twth, clip)) * waha
     xbyb = box_pred_txty * waha + xaya
     x1y1 = xbyb - wbhb * 0.5
     x2y2 = xbyb + wbhb * 0.5    # (...)x1x2
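The inlined `clip` replaces the old precomputed `BBOX_DECODE_CLIP`; numerically it is the same bound. With the default `PREPROC.MAX_SIZE = 1333`:

    import numpy as np
    clip = np.log(1333 / 16.)   # ~4.42
    # tf.minimum(box_pred_twth, clip) caps exp(tw) at 1333/16 ~ 83, i.e. a decoded
    # box can grow at most ~83x its anchor, preventing float overflow when a wild
    # regression output is exponentiated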
@@ -174,7 +175,6 @@ if __name__ == '__main__':
     Demonstrate what's wrong with tf.image.crop_and_resize:
     """
     import tensorflow.contrib.eager as tfe
-    import numpy as np
     tfe.enable_eager_execution()
     # want to crop 2x2 out of a 5x5 image, and resize to 4x4
...
train.py
@@ -47,10 +47,9 @@ from data import (
 from viz import (
     draw_annotation, draw_proposal_recall,
     draw_predictions, draw_final_outputs)
-from common import print_config, write_config_from_args
 from eval import (
     eval_coco, detect_one_image, print_evaluation_scores, DetectionResult)
-import config
+from config import config as cfg


 class DetectionModel(ModelDesc):
@@ -81,12 +80,12 @@ class DetectionModel(ModelDesc):
         lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
         tf.summary.scalar('learning_rate-summary', lr)

-        factor = config.NUM_GPUS / 8.
+        factor = cfg.TRAIN.NUM_GPUS / 8.
         if factor != 1:
             lr = lr * factor
         opt = tf.train.MomentumOptimizer(lr, 0.9)
-        if config.NUM_GPUS < 8:
-            opt = optimizer.AccumGradOptimizer(opt, 8 // config.NUM_GPUS)
+        if cfg.TRAIN.NUM_GPUS < 8:
+            opt = optimizer.AccumGradOptimizer(opt, 8 // cfg.TRAIN.NUM_GPUS)
         return opt

     def fastrcnn_training(self, image,
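Concretely, the schedule in config.py is defined for a total batch of 8 images, and the optimizer above rescales for other GPU counts. A worked example assuming 2 GPUs:

    # cfg.TRAIN.NUM_GPUS = 2:
    #   factor = 2 / 8. = 0.25 -> lr becomes 0.25x the configured BASE_LR schedule
    #   8 // 2 = 4 -> gradients are accumulated over 4 iterations before applying,
    #   so each applied update has seen 2 x 4 = 8 training images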
...@@ -111,7 +110,7 @@ class DetectionModel(ModelDesc): ...@@ -111,7 +110,7 @@ class DetectionModel(ModelDesc):
tf.summary.image('viz', fg_sampled_patches, max_outputs=30) tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
encoded_boxes = encode_bbox_target( encoded_boxes = encode_bbox_target(
gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS, dtype=tf.float32) gt_boxes_per_fg, fg_rcnn_boxes) * tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses( fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, rcnn_label_logits, rcnn_labels, rcnn_label_logits,
encoded_boxes, encoded_boxes,
...@@ -132,12 +131,12 @@ class DetectionModel(ModelDesc): ...@@ -132,12 +131,12 @@ class DetectionModel(ModelDesc):
labels (m): each >= 1 labels (m): each >= 1
""" """
rcnn_box_logits = rcnn_box_logits[:, 1:, :] rcnn_box_logits = rcnn_box_logits[:, 1:, :]
rcnn_box_logits.set_shape([None, config.NUM_CLASS - 1, None]) rcnn_box_logits.set_shape([None, cfg.DATA.NUM_CLASS - 1, None])
label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs') # #proposal x #Class
anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4 anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, cfg.DATA.NUM_CLASS - 1, 1]) # #proposal x #Cat x 4
decoded_boxes = decode_bbox_target( decoded_boxes = decode_bbox_target(
rcnn_box_logits / rcnn_box_logits /
tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS, dtype=tf.float32), anchors) tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32), anchors)
decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes') decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')
# indices: Nx2. Each index into (#proposal, #category) # indices: Nx2. Each index into (#proposal, #category)
...@@ -156,7 +155,7 @@ class DetectionModel(ModelDesc): ...@@ -156,7 +155,7 @@ class DetectionModel(ModelDesc):
[str]: output names [str]: output names
""" """
out = ['final_boxes', 'final_probs', 'final_labels'] out = ['final_boxes', 'final_probs', 'final_labels']
if config.MODE_MASK: if cfg.MODE_MASK:
out.append('final_masks') out.append('final_masks')
return ['image'], out return ['image'], out
...@@ -165,11 +164,11 @@ class ResNetC4Model(DetectionModel): ...@@ -165,11 +164,11 @@ class ResNetC4Model(DetectionModel):
def inputs(self): def inputs(self):
ret = [ ret = [
tf.placeholder(tf.float32, (None, None, 3), 'image'), tf.placeholder(tf.float32, (None, None, 3), 'image'),
tf.placeholder(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'), tf.placeholder(tf.int32, (None, None, cfg.RPN.NUM_ANCHOR), 'anchor_labels'),
tf.placeholder(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'), tf.placeholder(tf.float32, (None, None, cfg.RPN.NUM_ANCHOR, 4), 'anchor_boxes'),
tf.placeholder(tf.float32, (None, 4), 'gt_boxes'), tf.placeholder(tf.float32, (None, 4), 'gt_boxes'),
tf.placeholder(tf.int64, (None,), 'gt_labels')] # all > 0 tf.placeholder(tf.int64, (None,), 'gt_labels')] # all > 0
if config.MODE_MASK: if cfg.MODE_MASK:
ret.append( ret.append(
tf.placeholder(tf.uint8, (None, None, None), 'gt_masks') tf.placeholder(tf.uint8, (None, None, None), 'gt_masks')
) # NR_GT x height x width ) # NR_GT x height x width
...@@ -177,14 +176,14 @@ class ResNetC4Model(DetectionModel): ...@@ -177,14 +176,14 @@ class ResNetC4Model(DetectionModel):
def build_graph(self, *inputs): def build_graph(self, *inputs):
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
if config.MODE_MASK: if cfg.MODE_MASK:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
else: else:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self.preprocess(image) # 1CHW image = self.preprocess(image) # 1CHW
featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3]) featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, cfg.RPN.NUM_ANCHOR)
fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap( fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
featuremap, get_all_anchors(), anchor_labels, anchor_boxes) featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
...@@ -196,8 +195,8 @@ class ResNetC4Model(DetectionModel): ...@@ -196,8 +195,8 @@ class ResNetC4Model(DetectionModel):
tf.reshape(pred_boxes_decoded, [-1, 4]), tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]), tf.reshape(rpn_label_logits, [-1]),
image_shape2d, image_shape2d,
config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK, cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK) cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK)
if is_training: if is_training:
# sample proposal boxes in training # sample proposal boxes in training
...@@ -208,13 +207,13 @@ class ResNetC4Model(DetectionModel): ...@@ -208,13 +207,13 @@ class ResNetC4Model(DetectionModel):
# Use all proposal boxes in inference # Use all proposal boxes in inference
rcnn_boxes = proposal_boxes rcnn_boxes = proposal_boxes
boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE) boxes_on_featuremap = rcnn_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14) roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7 feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1]) # nxcx7x7
# Keep C5 feature to be shared with mask branch # Keep C5 feature to be shared with mask branch
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first') feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS) fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
if is_training: if is_training:
# rpn loss # rpn loss
...@@ -232,13 +231,13 @@ class ResNetC4Model(DetectionModel): ...@@ -232,13 +231,13 @@ class ResNetC4Model(DetectionModel):
image, rcnn_labels, fg_sampled_boxes, image, rcnn_labels, fg_sampled_boxes,
matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits) matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
if config.MODE_MASK: if cfg.MODE_MASK:
# maskrcnn loss # maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample) fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
# In training, mask branch shares the same C5 feature. # In training, mask branch shares the same C5 feature.
fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample) fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
mask_logits = maskrcnn_upXconv_head( mask_logits = maskrcnn_upXconv_head(
'maskrcnn', fg_feature, config.NUM_CLASS, num_convs=0) # #fg x #cat x 14x14 'maskrcnn', fg_feature, cfg.DATA.NUM_CLASS, num_convs=0) # #fg x #cat x 14x14
target_masks_for_fg = crop_and_resize( target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks, 1), tf.expand_dims(gt_masks, 1),
...@@ -252,7 +251,7 @@ class ResNetC4Model(DetectionModel): ...@@ -252,7 +251,7 @@ class ResNetC4Model(DetectionModel):
wd_cost = regularize_cost( wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W', '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost') l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
total_cost = tf.add_n([ total_cost = tf.add_n([
rpn_label_loss, rpn_box_loss, rpn_label_loss, rpn_box_loss,
...@@ -266,11 +265,11 @@ class ResNetC4Model(DetectionModel): ...@@ -266,11 +265,11 @@ class ResNetC4Model(DetectionModel):
final_boxes, final_labels = self.fastrcnn_inference( final_boxes, final_labels = self.fastrcnn_inference(
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits) image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK: if cfg.MODE_MASK:
roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14) roi_resized = roi_align(featuremap, final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) feature_maskrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])
mask_logits = maskrcnn_upXconv_head( mask_logits = maskrcnn_upXconv_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0) # #result x #cat x 14x14 'maskrcnn', feature_maskrcnn, cfg.DATA.NUM_CLASS, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1) indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14 final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
tf.sigmoid(final_mask_logits, name='final_masks') tf.sigmoid(final_mask_logits, name='final_masks')
...@@ -280,8 +279,8 @@ class ResNetFPNModel(DetectionModel): ...@@ -280,8 +279,8 @@ class ResNetFPNModel(DetectionModel):
def inputs(self): def inputs(self):
ret = [ ret = [
tf.placeholder(tf.float32, (None, None, 3), 'image')] tf.placeholder(tf.float32, (None, None, 3), 'image')]
num_anchors = len(config.ANCHOR_RATIOS) num_anchors = len(cfg.RPN.ANCHOR_RATIOS)
for k in range(len(config.ANCHOR_STRIDES_FPN)): for k in range(len(cfg.FPN.ANCHOR_STRIDES)):
ret.extend([ ret.extend([
tf.placeholder(tf.int32, (None, None, num_anchors), tf.placeholder(tf.int32, (None, None, num_anchors),
'anchor_labels_lvl{}'.format(k + 2)), 'anchor_labels_lvl{}'.format(k + 2)),
...@@ -290,33 +289,33 @@ class ResNetFPNModel(DetectionModel): ...@@ -290,33 +289,33 @@ class ResNetFPNModel(DetectionModel):
ret.extend([ ret.extend([
tf.placeholder(tf.float32, (None, 4), 'gt_boxes'), tf.placeholder(tf.float32, (None, 4), 'gt_boxes'),
tf.placeholder(tf.int64, (None,), 'gt_labels')]) # all > 0 tf.placeholder(tf.int64, (None,), 'gt_labels')]) # all > 0
if config.MODE_MASK: if cfg.MODE_MASK:
ret.append( ret.append(
tf.placeholder(tf.uint8, (None, None, None), 'gt_masks') tf.placeholder(tf.uint8, (None, None, None), 'gt_masks')
) # NR_GT x height x width ) # NR_GT x height x width
return ret return ret
def build_graph(self, *inputs): def build_graph(self, *inputs):
num_fpn_level = len(config.ANCHOR_STRIDES_FPN) num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES)
assert len(config.ANCHOR_SIZES) == num_fpn_level assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
image = inputs[0] image = inputs[0]
input_anchors = inputs[1: 1 + 2 * num_fpn_level] input_anchors = inputs[1: 1 + 2 * num_fpn_level]
multilevel_anchor_labels = input_anchors[0::2] multilevel_anchor_labels = input_anchors[0::2]
multilevel_anchor_boxes = input_anchors[1::2] multilevel_anchor_boxes = input_anchors[1::2]
gt_boxes, gt_labels = inputs[11], inputs[12] gt_boxes, gt_labels = inputs[11], inputs[12]
if config.MODE_MASK: if cfg.MODE_MASK:
gt_masks = inputs[-1] gt_masks = inputs[-1]
image = self.preprocess(image) # 1CHW image = self.preprocess(image) # 1CHW
image_shape2d = tf.shape(image)[2:] # h,w image_shape2d = tf.shape(image)[2:] # h,w
c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK) c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
p23456 = fpn_model('fpn', c2345) p23456 = fpn_model('fpn', c2345)
# Images are padded for p5, which are too large for p2-p4. # Images are padded for p5, which are too large for p2-p4.
# This seems to have no effect on mAP. # This seems to have no effect on mAP.
for i, stride in enumerate(config.ANCHOR_STRIDES_FPN[:3]): for i, stride in enumerate(cfg.FPN.ANCHOR_STRIDES[:3]):
pi = p23456[i] pi = p23456[i]
target_shape = tf.to_int32(tf.ceil(tf.to_float(image_shape2d) * (1.0 / stride))) target_shape = tf.to_int32(tf.ceil(tf.to_float(image_shape2d) * (1.0 / stride)))
p23456[i] = tf.slice(pi, [0, 0, 0, 0], p23456[i] = tf.slice(pi, [0, 0, 0, 0],
...@@ -328,7 +327,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -328,7 +327,7 @@ class ResNetFPNModel(DetectionModel):
rpn_loss_collection = [] rpn_loss_collection = []
for lvl in range(num_fpn_level): for lvl in range(num_fpn_level):
rpn_label_logits, rpn_box_logits = rpn_head( rpn_label_logits, rpn_box_logits = rpn_head(
'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS)) 'rpn', p23456[lvl], cfg.FPN.NUM_CHANNEL, len(cfg.RPN.ANCHOR_RATIOS))
with tf.name_scope('FPN_lvl{}'.format(lvl + 2)): with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
anchors = tf.constant(get_all_anchors_fpn()[lvl], name='rpn_anchor_lvl{}'.format(lvl + 2)) anchors = tf.constant(get_all_anchors_fpn()[lvl], name='rpn_anchor_lvl{}'.format(lvl + 2))
anchors, anchor_labels, anchor_boxes = \ anchors, anchor_labels, anchor_boxes = \
...@@ -341,7 +340,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -341,7 +340,7 @@ class ResNetFPNModel(DetectionModel):
tf.reshape(pred_boxes_decoded, [-1, 4]), tf.reshape(pred_boxes_decoded, [-1, 4]),
tf.reshape(rpn_label_logits, [-1]), tf.reshape(rpn_label_logits, [-1]),
image_shape2d, image_shape2d,
config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK) cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK)
multilevel_proposals.append((proposal_boxes, proposal_scores)) multilevel_proposals.append((proposal_boxes, proposal_scores))
if is_training: if is_training:
label_loss, box_loss = rpn_losses( label_loss, box_loss = rpn_losses(
@@ -353,7 +352,7 @@ class ResNetFPNModel(DetectionModel):
         proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
         proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
         proposal_topk = tf.minimum(tf.size(proposal_scores),
-                                   config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
+                                   cfg.RPN.TRAIN_FPN_NMS_TOPK if is_training else cfg.RPN.TEST_FPN_NMS_TOPK)
         proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
         proposal_boxes = tf.gather(proposal_boxes, topk_indices)
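After the per-level proposals are pooled, only the overall best K survive; `tf.minimum` guards against asking for more proposals than exist. The same selection, sketched in NumPy with made-up numbers:

```python
import numpy as np

# Made-up proposals pooled across FPN levels; TOPK stands in for the config value.
scores = np.array([0.9, 0.2, 0.8, 0.5])
boxes = np.arange(16, dtype=np.float32).reshape(4, 4)   # one xyxy box per row
TOPK = 3

k = min(scores.size, TOPK)               # tf.minimum(tf.size(proposal_scores), TOPK)
topk_indices = np.argsort(-scores)[:k]   # tf.nn.top_k(..., sorted=False)
print(boxes[topk_indices])               # tf.gather(proposal_boxes, topk_indices)
```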
@@ -366,9 +365,9 @@ class ResNetFPNModel(DetectionModel):
         roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
-        fastrcnn_head_func = getattr(model, config.FPN_FASTRCNN_HEAD_FUNC)
+        fastrcnn_head_func = getattr(model, cfg.FPN.FRCNN_HEAD_FUNC)
         fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
-            'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)
+            'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
         if is_training:
             # rpn loss is already defined above
@@ -388,13 +387,13 @@ class ResNetFPNModel(DetectionModel):
                 image, rcnn_labels, fg_sampled_boxes,
                 matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)
-            if config.MODE_MASK:
+            if cfg.MODE_MASK:
                 # maskrcnn loss
                 fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                 roi_feature_maskrcnn = multilevel_roi_align(
                     p23456[:4], fg_sampled_boxes, 14)
                 mask_logits = maskrcnn_upXconv_head(
-                    'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)   # #fg x #cat x 28 x 28
+                    'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CLASS, 4)   # #fg x #cat x 28 x 28
                 target_masks_for_fg = crop_and_resize(
                     tf.expand_dims(gt_masks, 1),
@@ -408,7 +407,7 @@ class ResNetFPNModel(DetectionModel):
             wd_cost = regularize_cost(
                 '(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
-                l2_regularizer(1e-4), name='wd_cost')
+                l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
             total_cost = tf.add_n(rpn_loss_collection + [
                 fastrcnn_label_loss, fastrcnn_box_loss,
@@ -419,11 +418,11 @@ class ResNetFPNModel(DetectionModel):
         else:
             final_boxes, final_labels = self.fastrcnn_inference(
                 image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
-            if config.MODE_MASK:
+            if cfg.MODE_MASK:
                 # Cascade inference needs roi transform with refined boxes.
                 roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)
                 mask_logits = maskrcnn_upXconv_head(
-                    'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)   # #fg x #cat x 28 x 28
+                    'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CLASS, 4)   # #fg x #cat x 28 x 28
                 indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
                 final_mask_logits = tf.gather_nd(mask_logits, indices)   # #resultx28x28
                 tf.sigmoid(final_mask_logits, name='final_masks')
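The `tf.gather_nd` above picks, for each detection, the single 28x28 mask map belonging to its predicted class; labels are 1-based (0 is background), hence the `- 1`. The equivalent indexing in NumPy, with made-up shapes:

```python
import numpy as np

num_results, num_categories = 2, 80
mask_logits = np.random.randn(num_results, num_categories, 28, 28)
final_labels = np.array([3, 17])                          # 1-based class ids
indices = np.stack([np.arange(num_results), final_labels - 1], axis=1)
final_masks = mask_logits[indices[:, 0], indices[:, 1]]   # like tf.gather_nd -> (2, 28, 28)
```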
@@ -537,19 +536,19 @@ def init_config():
     """
     Initialize config for training.
     """
-    if config.TRAINER == 'horovod':
+    if cfg.TRAINER == 'horovod':
         ngpu = hvd.size()
     else:
         ngpu = get_num_gpu()
     assert ngpu % 8 == 0 or 8 % ngpu == 0, ngpu
-    if config.NUM_GPUS is None:
-        config.NUM_GPUS = ngpu
+    if cfg.TRAIN.NUM_GPUS is None:
+        cfg.TRAIN.NUM_GPUS = ngpu
     else:
-        if config.TRAINER == 'horovod':
-            assert config.NUM_GPUS == ngpu
+        if cfg.TRAINER == 'horovod':
+            assert cfg.TRAIN.NUM_GPUS == ngpu
         else:
-            assert config.NUM_GPUS <= ngpu
-    print_config()
+            assert cfg.TRAIN.NUM_GPUS <= ngpu
+    logger.info("Config: ------------------------------------------\n" + str(cfg))

 if __name__ == '__main__':
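The `ngpu % 8 == 0 or 8 % ngpu == 0` guard accepts GPU counts that divide or are multiples of 8, so the 8-GPU reference schedule below rescales by an integer ratio in one direction or the other. A quick check:

```python
# Counts accepted by the assertion above: divisors and multiples of 8.
for ngpu in [1, 2, 3, 4, 5, 6, 8, 16, 24]:
    print(ngpu, ngpu % 8 == 0 or 8 % ngpu == 0)   # only 3, 5 and 6 fail
```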
@@ -569,22 +568,22 @@ if __name__ == '__main__':
         logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.")
     args = parser.parse_args()
-    write_config_from_args(args.config)
-    MODEL = ResNetFPNModel() if config.MODE_FPN else ResNetC4Model()
+    cfg.update_args(args.config)
+    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
     if args.visualize or args.evaluate or args.predict:
         # autotune is too slow for inference
         os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
         assert args.load
-        print_config()
+        logger.info("Config: ------------------------------------------\n" + str(cfg))
         if args.predict or args.visualize:
-            config.RESULT_SCORE_THRESH = config.RESULT_SCORE_THRESH_VIS
+            cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS
         if args.visualize:
-            assert not config.MODE_FPN, "FPN visualize is not supported!"
+            assert not cfg.MODE_FPN, "FPN visualize is not supported!"
             visualize(args.load)
         else:
             pred = OfflinePredictor(PredictConfig(
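`cfg.update_args(args.config)` applies command-line overrides to the new attrdict. A hypothetical sketch of the idea, assuming overrides arrive as dotted `KEY=VALUE` strings with Python-literal values (the real parsing lives in `config.py`; `apply_overrides` and the example keys are stand-ins):

```python
import ast

def apply_overrides(cfg, kv_list):
    # e.g. kv_list = ["MODE_MASK=False", "TRAIN.NUM_GPUS=4"]  (hypothetical values)
    for kv in kv_list:
        key, value = kv.split('=', maxsplit=1)
        *parents, leaf = key.split('.')
        namespace = cfg
        for p in parents:
            namespace = getattr(namespace, p)           # walk nested sections, e.g. TRAIN
        setattr(namespace, leaf, ast.literal_eval(value))
```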
@@ -596,11 +595,11 @@ if __name__ == '__main__':
                 assert args.evaluate.endswith('.json'), args.evaluate
                 offline_evaluate(pred, args.evaluate)
             elif args.predict:
-                COCODetection(config.BASEDIR, 'val2014')   # Only to load the class names into caches
+                COCODetection(cfg.DATA.BASEDIR, 'val2014')   # Only to load the class names into caches
                 predict(pred, args.predict)
     else:
         os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
-        is_horovod = config.TRAINER == 'horovod'
+        is_horovod = cfg.TRAINER == 'horovod'
         if is_horovod:
             hvd.init()
             logger.info("Horovod Rank={}, Size={}".format(hvd.rank(), hvd.size()))
@@ -611,17 +610,17 @@ if __name__ == '__main__':
         logger.set_logger_dir(args.logdir, 'd')
         init_config()
-        factor = 8. / config.NUM_GPUS
-        stepnum = config.STEPS_PER_EPOCH
+        factor = 8. / cfg.TRAIN.NUM_GPUS
+        stepnum = cfg.TRAIN.STEPS_PER_EPOCH
         # warmup is step based, lr is epoch based
-        warmup_schedule = [(0, config.BASE_LR / 3), (config.WARMUP * factor, config.BASE_LR)]
-        warmup_end_epoch = config.WARMUP * factor * 1. / stepnum
+        warmup_schedule = [(0, cfg.TRAIN.BASE_LR / 3), (cfg.TRAIN.WARMUP * factor, cfg.TRAIN.BASE_LR)]
+        warmup_end_epoch = cfg.TRAIN.WARMUP * factor * 1. / stepnum
         lr_schedule = [(int(np.ceil(warmup_end_epoch)), warmup_schedule[-1][1])]
-        for idx, steps in enumerate(config.LR_SCHEDULE[:-1]):
+        for idx, steps in enumerate(cfg.TRAIN.LR_SCHEDULE[:-1]):
             mult = 0.1 ** (idx + 1)
             lr_schedule.append(
-                (steps * factor // stepnum, config.BASE_LR * mult))
+                (steps * factor // stepnum, cfg.TRAIN.BASE_LR * mult))
         logger.info("Warm Up Schedule (steps, value): " + str(warmup_schedule))
         logger.info("LR Schedule (epochs, value): " + str(lr_schedule))
@@ -641,12 +640,12 @@ if __name__ == '__main__':
         if not is_horovod:
             callbacks.append(GPUUtilizationTracker())
-        cfg = TrainConfig(
+        traincfg = TrainConfig(
             model=MODEL,
             data=QueueInput(get_train_dataflow()),
             callbacks=callbacks,
             steps_per_epoch=stepnum,
-            max_epoch=config.LR_SCHEDULE[-1] * factor // stepnum,
+            max_epoch=cfg.TRAIN.LR_SCHEDULE[-1] * factor // stepnum,
             session_init=get_model_loader(args.load) if args.load else None,
         )
@@ -654,5 +653,5 @@ if __name__ == '__main__':
             trainer = HorovodTrainer()
         else:
             # nccl mode has better speed than cpu mode
-            trainer = SyncMultiGPUTrainerReplicated(config.NUM_GPUS, mode='nccl')
-        launch_train_with_config(cfg, trainer)
+            trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, mode='nccl')
+        launch_train_with_config(traincfg, trainer)
@@ -8,7 +8,7 @@ from tensorpack.utils import viz
 from tensorpack.utils.palette import PALETTE_RGB
 from utils.np_box_ops import iou as np_iou
-import config
+from config import config as cfg

 def draw_annotation(img, boxes, klass, is_crowd=None):
@@ -17,13 +17,13 @@ def draw_annotation(img, boxes, klass, is_crowd=None):
     if is_crowd is not None:
         assert len(boxes) == len(is_crowd)
         for cls, crd in zip(klass, is_crowd):
-            clsname = config.CLASS_NAMES[cls]
+            clsname = cfg.DATA.CLASS_NAMES[cls]
             if crd == 1:
                 clsname += ';Crowd'
             labels.append(clsname)
     else:
         for cls in klass:
-            labels.append(config.CLASS_NAMES[cls])
+            labels.append(cfg.DATA.CLASS_NAMES[cls])
     img = viz.draw_boxes(img, boxes, labels)
     return img
@@ -57,7 +57,7 @@ def draw_predictions(img, boxes, scores):
         return img
     labels = scores.argmax(axis=1)
     scores = scores.max(axis=1)
-    tags = ["{},{:.2f}".format(config.CLASS_NAMES[lb], score) for lb, score in zip(labels, scores)]
+    tags = ["{},{:.2f}".format(cfg.DATA.CLASS_NAMES[lb], score) for lb, score in zip(labels, scores)]
     return viz.draw_boxes(img, boxes, tags)
@@ -72,7 +72,7 @@ def draw_final_outputs(img, results):
     tags = []
     for r in results:
         tags.append(
-            "{},{:.2f}".format(config.CLASS_NAMES[r.class_id], r.score))
+            "{},{:.2f}".format(cfg.DATA.CLASS_NAMES[r.class_id], r.score))
     boxes = np.asarray([r.box for r in results])
     ret = viz.draw_boxes(img, boxes, tags)
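A hypothetical usage sketch for the drawing helper above, assuming results are record objects exposing `box`, `score`, and `class_id` (names taken from the attribute accesses in the hunk) and that `cfg.DATA.CLASS_NAMES` has been populated, e.g. by constructing `COCODetection` first:

```python
from collections import namedtuple
import numpy as np

# Hypothetical stand-in for the real detection-result type.
Result = namedtuple('Result', ['box', 'score', 'class_id'])

img = np.zeros((480, 640, 3), np.uint8)
results = [Result(box=np.array([10., 20., 200., 240.]), score=0.97, class_id=1)]
canvas = draw_final_outputs(img, results)   # boxes drawn with "classname,0.97" tags
```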