Commit e233d835 authored by Yuxin Wu's avatar Yuxin Wu

update fpn

parent 21d54280
# Faster-RCNN / Mask-RCNN on COCO # Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) implementation of This example aims to provide a minimal (1.3k lines) implementation of
end-to-end Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO. end-to-end Faster-RCNN & Mask-RCNN (with ResNet & FPN backbones) on COCO.
## Dependencies ## Dependencies
+ Python 3; TensorFlow >= 1.4.0 + Python 3; TensorFlow >= 1.4.0 (>=1.6.0 recommended due to a TF bug);
+ [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV. + [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
+ Pre-trained [ResNet model](http://models.tensorpack.com/ResNet/) from tensorpack model zoo. + Pre-trained [ResNet model](http://models.tensorpack.com/ResNet/) from tensorpack model zoo.
+ COCO data. It assumes the following directory structure: + COCO data. It assumes the following directory structure:
...@@ -61,7 +61,7 @@ MaskRCNN results contain both bbox and segm mAP. ...@@ -61,7 +61,7 @@ MaskRCNN results contain both bbox and segm mAP.
|R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s| |R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s| |R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
The two R-50 360k models have the same configuration __and mAP__ The two R-50 360k models have the same configuration __and mAP__
as the `R50-C4-2x` entries in as the `R50-C4-2x` entries in
[Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines). [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines).
So far this seems to be the only open source re-implementation that can reproduce mAP in Detectron. So far this seems to be the only open source re-implementation that can reproduce mAP in Detectron.
......
...@@ -5,11 +5,11 @@ import numpy as np ...@@ -5,11 +5,11 @@ import numpy as np
# mode flags --------------------- # mode flags ---------------------
MODE_MASK = True MODE_MASK = True
MODE_FPN = False
# dataset ----------------------- # dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR' BASEDIR = '/path/to/your/COCO/DIR'
TRAIN_DATASET = ['train2014', 'valminusminival2014'] TRAIN_DATASET = ['train2014', 'valminusminival2014']
# TRAIN_DATASET = ['valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on single dataset VAL_DATASET = 'minival2014' # only support evaluation on single dataset
NUM_CLASS = 81 NUM_CLASS = 81
CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader
...@@ -29,14 +29,14 @@ LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron ...@@ -29,14 +29,14 @@ LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution -------------------- # image resolution --------------------
SHORT_EDGE_SIZE = 800 SHORT_EDGE_SIZE = 800
MAX_SIZE = 1333 # TODO use 1344 MAX_SIZE = 1333
# alternative (worse & faster) setting: 600, 1024 # alternative (worse & faster) setting: 600, 1024
# anchors ------------------------- # anchors -------------------------
ANCHOR_STRIDE = 16 ANCHOR_STRIDE = 16
ANCHOR_STRIDES_FPN = (4, 8, 16, 32, 64) ANCHOR_STRIDES_FPN = (4, 8, 16, 32, 64)
# sqrtarea of the anchor box FPN_RESOLUTION_REQUIREMENT = 32 # image size into the backbone has to be multiple of this number
ANCHOR_SIZES = (32, 64, 128, 256, 512) ANCHOR_SIZES = (32, 64, 128, 256, 512) # sqrtarea of the anchor box
ANCHOR_RATIOS = (0.5, 1., 2.) ANCHOR_RATIOS = (0.5, 1., 2.)
NUM_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS) NUM_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
POSITIVE_ANCHOR_THRES = 0.7 POSITIVE_ANCHOR_THRES = 0.7
...@@ -52,6 +52,7 @@ RPN_MIN_SIZE = 0 ...@@ -52,6 +52,7 @@ RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH = 0.7 RPN_PROPOSAL_NMS_THRESH = 0.7
TRAIN_PRE_NMS_TOPK = 12000 TRAIN_PRE_NMS_TOPK = 12000
TRAIN_POST_NMS_TOPK = 2000 TRAIN_POST_NMS_TOPK = 2000
TRAIN_FPN_NMS_TOPK = 2000
# boxes overlapping crowd will be ignored. # boxes overlapping crowd will be ignored.
CROWD_OVERLAP_THRES = 0.7 CROWD_OVERLAP_THRES = 0.7
...@@ -62,19 +63,16 @@ FASTRCNN_FG_THRESH = 0.5 ...@@ -62,19 +63,16 @@ FASTRCNN_FG_THRESH = 0.5
# fg ratio in a ROI batch # fg ratio in a ROI batch
FASTRCNN_FG_RATIO = 0.25 FASTRCNN_FG_RATIO = 0.25
# modeling -------------------------
FPN_NUM_CHANNEL = 256
FASTRCNN_FC_HEAD_DIM = 1024
MASKRCNN_HEAD_DIM = 256
# testing ----------------------- # testing -----------------------
TEST_PRE_NMS_TOPK = 6000 TEST_PRE_NMS_TOPK = 6000
TEST_POST_NMS_TOPK = 1000 # if you encounter OOM in inference, set this to a smaller number TEST_POST_NMS_TOPK = 1000 # if you encounter OOM in inference, set this to a smaller number
TEST_FPN_NMS_TOPK = 1000
FASTRCNN_NMS_THRESH = 0.5 FASTRCNN_NMS_THRESH = 0.5
RESULT_SCORE_THRESH = 0.05 RESULT_SCORE_THRESH = 0.05
RESULT_SCORE_THRESH_VIS = 0.3 # only visualize confident results RESULT_SCORE_THRESH_VIS = 0.3 # only visualize confident results
RESULTS_PER_IM = 100 RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE
MODE_FPN = True
FPN_NUM_CHANNEL = 256
MASKRCNN_HEAD_DIM = 256
FASTRCNN_FC_HEAD_DIM = 1024
FPN_RESOLUTION_REQUIREMENT = 32
TRAIN_FPN_NMS_TOPK = 2000
TEST_FPN_NMS_TOPK = 1000
...@@ -8,7 +8,7 @@ import itertools ...@@ -8,7 +8,7 @@ import itertools
from tensorpack.utils.argtools import memoized, log_once from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import ( from tensorpack.dataflow import (
imgaug, TestDataSpeed, PrefetchDataZMQ, MapData, MultiProcessMapDataZMQ, imgaug, TestDataSpeed, PrefetchDataZMQ, MultiProcessMapDataZMQ,
MapDataComponent, DataFromList) MapDataComponent, DataFromList)
# import tensorpack.utils.viz as tpviz # import tensorpack.utils.viz as tpviz
......
...@@ -12,7 +12,6 @@ from tensorpack.models import ( ...@@ -12,7 +12,6 @@ from tensorpack.models import (
Conv2D, FullyConnected, MaxPooling, Conv2D, FullyConnected, MaxPooling,
layer_register, Conv2DTranspose, FixedUnPooling) layer_register, Conv2DTranspose, FixedUnPooling)
from tensorpack.utils import logger
from utils.box_ops import pairwise_iou from utils.box_ops import pairwise_iou
from utils.box_ops import area as tf_area from utils.box_ops import area as tf_area
import config import config
...@@ -90,7 +89,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits): ...@@ -90,7 +89,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_label_prob > th, valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)), tf.equal(valid_prediction, valid_anchor_labels)),
dtype=tf.int32) dtype=tf.int32)
placeholder = 0.5 # TODO A small value will make summaries appear lower. placeholder = 0.5 # A small value will make summaries appear lower.
recall = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos)) recall = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos))
recall = tf.where(tf.equal(nr_pos, 0), placeholder, recall, name='recall_th{}'.format(th)) recall = tf.where(tf.equal(nr_pos, 0), placeholder, recall, name='recall_th{}'.format(th))
precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction)) precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
...@@ -99,7 +98,9 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits): ...@@ -99,7 +98,9 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
summaries.extend([precision, recall]) summaries.extend([precision, recall])
add_moving_summary(*summaries) add_moving_summary(*summaries)
placeholder = 0. # Per-level loss summaries in FPN may appear lower. But the sum should be OK. # Per-level loss summaries in FPN may appear lower due to the use of a small placeholder.
# But the total loss is still the same.
placeholder = 0.
label_loss = tf.nn.sigmoid_cross_entropy_with_logits( label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits) labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
label_loss = tf.reduce_sum(label_loss) * (1. / config.RPN_BATCH_PER_IM) label_loss = tf.reduce_sum(label_loss) * (1. / config.RPN_BATCH_PER_IM)
...@@ -601,19 +602,18 @@ def fpn_model(features): ...@@ -601,19 +602,18 @@ def fpn_model(features):
num_channel = config.FPN_NUM_CHANNEL num_channel = config.FPN_NUM_CHANNEL
def upsample2x(name, x): def upsample2x(name, x):
# TODO may not be optimal in speed or math
logger.info("Unpool 1111 ...")
return FixedUnPooling( return FixedUnPooling(
name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'), name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
data_format='channels_first') data_format='channels_first')
with tf.name_scope(name): # tf.image.resize is, again, not aligned.
logger.info("Nearest neighbor") # with tf.name_scope(name):
shape2d = tf.shape(x)[2:] # logger.info("Nearest neighbor")
x = tf.transpose(x, [0, 2, 3, 1]) # shape2d = tf.shape(x)[2:]
x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True) # x = tf.transpose(x, [0, 2, 3, 1])
x = tf.transpose(x, [0, 3, 1, 2]) # x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
return x # x = tf.transpose(x, [0, 3, 1, 2])
# return x
with argscope(Conv2D, data_format='channels_first', with argscope(Conv2D, data_format='channels_first',
nl=tf.identity, use_bias=True, nl=tf.identity, use_bias=True,
...@@ -636,6 +636,8 @@ def fpn_model(features): ...@@ -636,6 +636,8 @@ def fpn_model(features):
@under_name_scope() @under_name_scope()
def fpn_map_rois_to_levels(boxes): def fpn_map_rois_to_levels(boxes):
""" """
Assign boxes to level 2~5.
Args: Args:
boxes (nx4) boxes (nx4)
......
...@@ -19,6 +19,7 @@ from tensorpack import * ...@@ -19,6 +19,7 @@ from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_number
import tensorpack.utils.viz as tpviz import tensorpack.utils.viz as tpviz
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
...@@ -33,8 +34,7 @@ from model import ( ...@@ -33,8 +34,7 @@ from model import (
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align, generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions, fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_upXconv_head, maskrcnn_loss, maskrcnn_upXconv_head, maskrcnn_loss,
fpn_model, fpn_map_rois_to_levels, fastrcnn_2fc_head, fpn_model, fastrcnn_2fc_head, multilevel_roi_align)
multilevel_roi_align)
from data import ( from data import (
get_train_dataflow, get_eval_dataflow, get_train_dataflow, get_eval_dataflow,
get_all_anchors, get_all_anchors_fpn) get_all_anchors, get_all_anchors_fpn)
...@@ -62,6 +62,8 @@ def get_model_output_names(): ...@@ -62,6 +62,8 @@ def get_model_output_names():
def get_model():
    """
    Build the detection model selected by ``config.MODE_FPN``.

    Returns:
        A ``ResNetFPNModel`` instance when ``config.MODE_FPN`` is truthy,
        otherwise a ``ResNetC4Model`` instance.
    """
    if config.MODE_FPN:
        # Same version helper that the C4 model uses for its TF >= 1.6 check.
        # NOTE(review): presumably this returns "major.minor" as a float; if so,
        # a release such as 1.10 would compare as 1.1 here -- confirm.
        if get_tf_version_number() < 1.6:
            logger.warn("FPN has chances to crash in TF<1.6, due to a TF issue.")
        return ResNetFPNModel()
    else:
        return ResNetC4Model()
...@@ -223,8 +225,12 @@ class ResNetC4Model(DetectionModel): ...@@ -223,8 +225,12 @@ class ResNetC4Model(DetectionModel):
ncls = config.NUM_CLASS ncls = config.NUM_CLASS
return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4]) return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond( if get_tf_version_number() >= 1.6:
tf.size(boxes_on_featuremap) > 0, ff_true, ff_false) feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
else:
logger.warn("This example may drop support for TF < 1.6 soon.")
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)
if is_training: if is_training:
# rpn loss # rpn loss
...@@ -434,10 +440,11 @@ class ResNetFPNModel(DetectionModel): ...@@ -434,10 +440,11 @@ class ResNetFPNModel(DetectionModel):
'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28 'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4) # #fg x #cat x 28 x 28
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1) indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28 final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx28x28
final_masks = tf.sigmoid(final_mask_logits, name='final_masks') tf.sigmoid(final_mask_logits, name='final_masks')
def visualize(model_path, nr_visualize=50, output_dir='output'): def visualize(model_path, nr_visualize=50, output_dir='output'):
assert not config.MODE_FPN, "FPN visualize is not supported yet!"
df = get_train_dataflow() # we don't visualize mask stuff df = get_train_dataflow() # we don't visualize mask stuff
df.reset_state() df.reset_state()
...@@ -577,7 +584,7 @@ if __name__ == '__main__': ...@@ -577,7 +584,7 @@ if __name__ == '__main__':
COCODetection(config.BASEDIR, 'val2014') # Only to load the class names into caches COCODetection(config.BASEDIR, 'val2014') # Only to load the class names into caches
predict(pred, args.predict) predict(pred, args.predict)
else: else:
logger.set_logger_dir(args.logdir, 'd') logger.set_logger_dir(args.logdir)
print_config() print_config()
factor = get_batch_factor() factor = get_batch_factor()
stepnum = config.STEPS_PER_EPOCH stepnum = config.STEPS_PER_EPOCH
...@@ -611,5 +618,5 @@ if __name__ == '__main__': ...@@ -611,5 +618,5 @@ if __name__ == '__main__':
max_epoch=config.LR_SCHEDULE[-1] * factor // stepnum, max_epoch=config.LR_SCHEDULE[-1] * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None, session_init=get_model_loader(args.load) if args.load else None,
) )
trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu()) trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu(), mode='cpu')
launch_train_with_config(cfg, trainer) launch_train_with_config(cfg, trainer)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment