Commit ebeb7445 authored by Yuxin Wu

[MaskRCNN] move all model code to modeling/ (#1163)

parent 791c7b45
......@@ -5,10 +5,10 @@ This is a minimal implementation that simply contains these files:
+ coco.py: load COCO data to the dataset interface
+ data.py: prepare data for training & inference
+ common.py: common data preparation utilities
- + backbone.py: implement backbones
- + generalized_rcnn.py: implement variants of generalized R-CNN architecture
- + model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models.
- + model_box.py: implement box-related symbolic functions
+ + modeling/generalized_rcnn.py: implement variants of generalized R-CNN architecture
+ + modeling/backbone.py: implement backbones
+ + modeling/model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models.
+ + modeling/model_box.py: implement box-related symbolic functions
+ train.py: main entry script
+ utils/: third-party helper functions
+ eval.py: evaluation utilities
......
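To make the README change above concrete: before this commit the model files sat next to the entry scripts and were imported as top-level modules; after it they live in the modeling/ package. A minimal sketch of the resulting import paths (the backbone import is only an illustrative example, not a line from this diff):

```python
# Old layout: model code as top-level modules next to train.py
# from generalized_rcnn import ResNetC4Model, ResNetFPNModel
# from backbone import resnet_fpn_backbone

# New layout: model code under the modeling/ package; entry scripts such as
# train.py import through the package (see the train.py hunk at the end of this diff)
from modeling.generalized_rcnn import ResNetC4Model, ResNetFPNModel
from modeling.backbone import resnet_fpn_backbone  # illustrative; pick whatever symbol you need
```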
......@@ -325,7 +325,7 @@ def get_train_dataflow():
        assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
        ret = {'image': im}
-       # rpn anchor:
+       # Add rpn data to dataflow:
        try:
            if cfg.MODE_FPN:
                multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
......@@ -333,7 +333,6 @@ def get_train_dataflow():
                    ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
                    ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
            else:
-               # anchor_labels, anchor_boxes
                ret['anchor_labels'], ret['anchor_boxes'] = get_rpn_anchor_input(im, boxes, is_crowd)
            boxes = boxes[is_crowd == 0]  # skip crowd boxes in training target
......
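Aside from the reworded comment, the data.py block above is untouched by this commit; it is where each training datapoint gets its RPN targets attached next to the image. A small sketch of the keys this produces, assuming the default FPN setting with five anchor strides (levels 2 through 6); the helper below is illustrative and not part of the codebase:

```python
def expected_training_keys(mode_fpn, num_fpn_levels=5):
    """Sketch of the dict keys emitted per datapoint by get_train_dataflow().

    Key names are taken from the assignments above; 'gt_boxes'/'gt_labels' are
    added later in the same function (plus 'gt_masks' when training with masks).
    The level count of 5 assumes the default FPN anchor strides.
    """
    keys = ['image', 'gt_boxes', 'gt_labels']
    if mode_fpn:
        for lvl in range(2, 2 + num_fpn_levels):
            keys += ['anchor_labels_lvl{}'.format(lvl), 'anchor_boxes_lvl{}'.format(lvl)]
    else:
        keys += ['anchor_labels', 'anchor_boxes']
    return keys

# expected_training_keys(mode_fpn=True)  -> image, gt_*, anchor_*_lvl2 ... anchor_*_lvl6
# expected_training_keys(mode_fpn=False) -> image, gt_*, anchor_labels, anchor_boxes
```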
......@@ -9,18 +9,19 @@ from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import get_current_tower_context
-import model_frcnn
-import model_mrcnn
-from backbone import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
from config import config as cfg
from data import get_all_anchors, get_all_anchors_fpn
-from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
-from model_cascade import CascadeRCNNHead
-from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
-from model_frcnn import (
+from . import model_frcnn
+from . import model_mrcnn
+from .backbone import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
+from .model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
+from .model_cascade import CascadeRCNNHead
+from .model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
+from .model_frcnn import (
    BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets)
-from model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
-from model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
+from .model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
+from .model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
class GeneralizedRCNN(ModelDesc):
......
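The import rewrite above is the whole substance of this hunk and of the hunks that follow: modules that moved into modeling/ now refer to each other with relative imports, while config.py, data.py and utils/ stayed at the top level, so their absolute imports are left alone. A minimal sketch of the distinction, using names from this hunk:

```python
# Inside modeling/generalized_rcnn.py (abbreviated sketch, not the full import list):

# Siblings that moved into the modeling/ package -> relative imports
from .backbone import resnet_fpn_backbone
from .model_box import clip_boxes

# Modules that did not move -> absolute imports, unchanged by this commit
from config import config as cfg
from data import get_all_anchors
```

Relative imports only resolve when these files are imported as part of the package, which is how they are normally used anyway (via train.py or the prediction script below).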
......@@ -3,9 +3,9 @@ import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
from config import config as cfg
-from model_box import clip_boxes
-from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
from utils.box_ops import pairwise_iou
+from .model_box import clip_boxes
+from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
class CascadeRCNNHead(object):
......
......@@ -10,11 +10,11 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import get_current_tower_context
-from backbone import GroupNorm
from config import config as cfg
-from model_box import roi_align
-from model_rpn import generate_rpn_proposals, rpn_losses
from utils.box_ops import area as tf_area
+from .backbone import GroupNorm
+from .model_box import roi_align
+from .model_rpn import generate_rpn_proposals, rpn_losses
@layer_register(log_shape=True)
......
......@@ -10,11 +10,12 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized_method
-from backbone import GroupNorm
from config import config as cfg
-from model_box import decode_bbox_target, encode_bbox_target
from utils.box_ops import pairwise_iou
+from .model_box import decode_bbox_target, encode_bbox_target
+from .backbone import GroupNorm
@under_name_scope()
def proposal_metrics(iou):
......
......@@ -8,7 +8,7 @@ from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
-from backbone import GroupNorm
+from .backbone import GroupNorm
from config import config as cfg
......
......@@ -8,7 +8,7 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name
from tensorpack.tfutils.summary import add_moving_summary
from config import config as cfg
-from model_box import clip_boxes
+from .model_box import clip_boxes
@layer_register(log_shape=True)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import itertools
import numpy as np
import os
import shutil
import tensorflow as tf
import cv2
import six
import tqdm
assert six.PY3, "This example requires Python 3!"
import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple
from tensorpack.utils import fs, logger
from coco import register_coco
from config import config as cfg
from config import finalize_configs
from data import get_eval_dataflow, get_train_dataflow
from dataset import DatasetRegistry
from eval import DetectionResult, multithread_predict_dataflow, predict_image
from modeling.generalized_rcnn import ResNetC4Model, ResNetFPNModel
from viz import draw_annotation, draw_final_outputs, draw_predictions, draw_proposal_recall
def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
    """
    Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
    """
    df = get_train_dataflow()
    df.reset_state()

    pred = OfflinePredictor(PredictConfig(
        model=model,
        session_init=get_model_loader(model_path),
        input_names=['image', 'gt_boxes', 'gt_labels'],
        output_names=[
            'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
            'generate_{}_proposals/scores'.format('fpn' if cfg.MODE_FPN else 'rpn'),
            'fastrcnn_all_scores',
            'output/boxes',
            'output/scores',
            'output/labels',
        ]))

    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)
    fs.mkdir_p(output_dir)
    with tqdm.tqdm(total=nr_visualize) as pbar:
        for idx, dp in itertools.islice(enumerate(df), nr_visualize):
            img, gt_boxes, gt_labels = dp['image'], dp['gt_boxes'], dp['gt_labels']

            rpn_boxes, rpn_scores, all_scores, \
                final_boxes, final_scores, final_labels = pred(img, gt_boxes, gt_labels)

            # draw groundtruth boxes
            gt_viz = draw_annotation(img, gt_boxes, gt_labels)
            # draw best proposals for each groundtruth, to show recall
            proposal_viz, good_proposals_ind = draw_proposal_recall(img, rpn_boxes, rpn_scores, gt_boxes)
            # draw the scores for the above proposals
            score_viz = draw_predictions(img, rpn_boxes[good_proposals_ind], all_scores[good_proposals_ind])

            results = [DetectionResult(*args) for args in
                       zip(final_boxes, final_scores, final_labels,
                           [None] * len(final_labels))]
            final_viz = draw_final_outputs(img, results)

            viz = tpviz.stack_patches([
                gt_viz, proposal_viz,
                score_viz, final_viz], 2, 2)

            if os.environ.get('DISPLAY', None):
                tpviz.interactive_imshow(viz)
            cv2.imwrite("{}/{:03d}.png".format(output_dir, idx), viz)
            pbar.update()
def do_evaluate(pred_config, output_file):
    num_gpu = cfg.TRAIN.NUM_GPUS
    graph_funcs = MultiTowerOfflinePredictor(
        pred_config, list(range(num_gpu))).get_predictors()

    for dataset in cfg.DATA.VAL:
        logger.info("Evaluating {} ...".format(dataset))
        dataflows = [
            get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
            for k in range(num_gpu)]
        all_results = multithread_predict_dataflow(dataflows, graph_funcs)
        output = output_file + '-' + dataset
        DatasetRegistry.get(dataset).eval_inference_results(all_results, output)
def do_predict(pred_func, input_file):
    img = cv2.imread(input_file, cv2.IMREAD_COLOR)
    results = predict_image(img, pred_func)
    final = draw_final_outputs(img, results)
    viz = np.concatenate((img, final), axis=1)
    cv2.imwrite("output.png", viz)
    logger.info("Inference output for {} written to output.png".format(input_file))
    tpviz.interactive_imshow(viz)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load a model for evaluation.', required=True)
    parser.add_argument('--visualize', action='store_true', help='visualize intermediate results')
    parser.add_argument('--evaluate', help="Run evaluation. "
                                           "This argument is the path to the output json evaluation file")
    parser.add_argument('--predict', help="Run prediction on a given image. "
                                          "This argument is the path to the input image file", nargs='+')
    parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
                        nargs='+')
    args = parser.parse_args()
    if args.config:
        cfg.update_args(args.config)
    register_coco(cfg.DATA.BASEDIR)  # add COCO datasets to the registry
    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()

    if not tf.test.is_gpu_available():
        from tensorflow.python.framework import test_util
        assert get_tf_version_tuple() >= (1, 7) and test_util.IsMklEnabled(), \
            "Inference requires either GPU support or MKL support!"
    assert args.load
    finalize_configs(is_training=False)

    if args.predict or args.visualize:
        cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS

    if args.visualize:
        do_visualize(MODEL, args.load)
    else:
        predcfg = PredictConfig(
            model=MODEL,
            session_init=get_model_loader(args.load),
            input_names=MODEL.get_inference_tensor_names()[0],
            output_names=MODEL.get_inference_tensor_names()[1])
        if args.predict:
            predictor = OfflinePredictor(predcfg)
            for image_file in args.predict:
                do_predict(predictor, image_file)
        elif args.evaluate:
            assert args.evaluate.endswith('.json'), args.evaluate
            do_evaluate(predcfg, args.evaluate)
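The __main__ block above is the script's public surface, so for orientation, here is a rough sketch of how its --config handling maps onto code; the checkpoint and data paths are placeholders, and the two keys shown (MODE_FPN, DATA.BASEDIR) are simply the ones already referenced in this script:

```python
# Roughly what the script does for
#   --load /path/to/checkpoint --config MODE_FPN=True DATA.BASEDIR=/path/to/COCO
# before building the model and the predictor (paths are placeholders):
from config import config as cfg
from config import finalize_configs

cfg.update_args(['MODE_FPN=True', 'DATA.BASEDIR=/path/to/COCO'])  # same form as the --config KEY=VALUE pairs
finalize_configs(is_training=False)  # freeze and validate the config for inference, as above
```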
......@@ -15,7 +15,7 @@ from config import config as cfg
from config import finalize_configs
from data import get_train_dataflow
from eval import EvalCallback
-from generalized_rcnn import ResNetC4Model, ResNetFPNModel
+from modeling.generalized_rcnn import ResNetC4Model, ResNetFPNModel
try:
......