Commit 791c7b45 authored by Yuxin Wu

[MaskRCNN] split train.py & predict.py (#1163)

parent 9b1b5f29
@@ -63,12 +63,12 @@ Some reasonable configurations are listed in the table below.
 To predict on an image (needs DISPLAY to show the outputs):
 ```
-./train.py --predict input1.jpg input2.jpg --load /path/to/Trained-Model-Checkpoint --config SAME-AS-TRAINING
+./predict.py --predict input1.jpg input2.jpg --load /path/to/Trained-Model-Checkpoint --config SAME-AS-TRAINING
 ```
 To evaluate the performance of a model on COCO:
 ```
-./train.py --evaluate output.json --load /path/to/Trained-Model-Checkpoint \
+./predict.py --evaluate output.json --load /path/to/Trained-Model-Checkpoint \
     --config SAME-AS-TRAINING
 ```

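These commands are thin wrappers around tensorpack's offline prediction API. A minimal sketch of the single-image path they drive, assembled from the `do_predict` and `PredictConfig` code this commit removes from train.py (shown further down in this diff); the checkpoint and image paths are placeholders:

```python
import cv2
from tensorpack import OfflinePredictor, PredictConfig, get_model_loader

from coco import register_coco
from config import config as cfg, finalize_configs
from eval import predict_image
from generalized_rcnn import ResNetC4Model, ResNetFPNModel
from viz import draw_final_outputs

# Register the COCO splits and freeze the config for inference.
register_coco(cfg.DATA.BASEDIR)
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
finalize_configs(is_training=False)

# Build an offline predictor from a trained checkpoint (placeholder path).
predcfg = PredictConfig(
    model=MODEL,
    session_init=get_model_loader('/path/to/Trained-Model-Checkpoint'),
    input_names=MODEL.get_inference_tensor_names()[0],
    output_names=MODEL.get_inference_tensor_names()[1])
predictor = OfflinePredictor(predcfg)

# Detect objects in one image and save the visualization, as do_predict() does.
img = cv2.imread('input1.jpg', cv2.IMREAD_COLOR)
results = predict_image(img, predictor)      # a list of DetectionResult tuples
cv2.imwrite('output.png', draw_final_outputs(img, results))
```
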
@@ -2,8 +2,8 @@
 # File: backbone.py
 import numpy as np
-from contextlib import ExitStack, contextmanager
 import tensorflow as tf
+from contextlib import ExitStack, contextmanager
 from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
 from tensorpack.tfutils import argscope

 # -*- coding: utf-8 -*-
+import json
 import numpy as np
 import os
 import tqdm
-import json
 from tensorpack.utils import logger
 from tensorpack.utils.timer import timed_operation
 from config import config as cfg
-from dataset import DatasetSplit, DatasetRegistry
+from dataset import DatasetRegistry, DatasetSplit
 __all__ = ['register_coco']

@@ -42,7 +42,7 @@ class COCODetection(DatasetSplit):
         self.name = name
         self._imgdir = os.path.realpath(os.path.join(
             basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
-        assert os.path.isdir(self._imgdir), self._imgdir
+        assert os.path.isdir(self._imgdir), "{} is not a directory!".format(self._imgdir)
         annotation_file = os.path.join(
             basedir, 'annotations/instances_{}.json'.format(name))
         assert os.path.isfile(annotation_file), annotation_file

@@ -3,8 +3,8 @@
 import numpy as np
 import os
-import six
 import pprint
+import six
 from tensorpack.utils import logger
 from tensorpack.utils.gpu import get_num_gpu

@@ -2,24 +2,26 @@
 # File: data.py
 import copy
+import itertools
 import numpy as np
 import cv2
-import itertools
 from tabulate import tabulate
 from termcolor import colored
 from tensorpack.dataflow import (
-    DataFromList, MapDataComponent, MapData, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
+    DataFromList, MapData, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData,
+    TestDataSpeed, imgaug)
 from tensorpack.utils import logger
 from tensorpack.utils.argtools import log_once, memoized
 from common import (
-    CustomResize, DataFromListOfDict, box_to_point8,
-    filter_boxes_inside_shape, point8_to_box, segmentation_to_mask, np_iou)
+    CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, np_iou,
+    point8_to_box, segmentation_to_mask)
 from config import config as cfg
 from dataset import DatasetRegistry
 from utils.generate_anchors import generate_anchors
-from utils.np_box_ops import area as np_area, ioa as np_ioa
+from utils.np_box_ops import area as np_area
+from utils.np_box_ops import ioa as np_ioa
 # import tensorpack.utils.viz as tpviz

@@ -2,17 +2,17 @@
 # File: eval.py
 import itertools
-import sys
-import os
 import json
 import numpy as np
+import os
+import sys
+import tensorflow as tf
 from collections import namedtuple
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import ExitStack
 import cv2
 import pycocotools.mask as cocomask
 import tqdm
-import tensorflow as tf
 from tensorpack.callbacks import Callback
 from tensorpack.tfutils.common import get_tf_version_tuple

@@ -20,9 +20,9 @@ from tensorpack.utils import logger
 from tensorpack.utils.utils import get_tqdm
 from common import CustomResize, clip_boxes
+from config import config as cfg
 from data import get_eval_dataflow
 from dataset import DatasetRegistry
-from config import config as cfg
 try:
     import horovod.tensorflow as hvd

@@ -4,22 +4,23 @@
 import tensorflow as tf
 from tensorpack import ModelDesc
-from tensorpack.models import regularize_cost, l2_regularizer, GlobalAvgPooling
+from tensorpack.models import GlobalAvgPooling, l2_regularizer, regularize_cost
-from tensorpack.tfutils.tower import get_current_tower_context
-from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.tfutils import optimizer
+from tensorpack.tfutils.summary import add_moving_summary
+from tensorpack.tfutils.tower import get_current_tower_context
 import model_frcnn
 import model_mrcnn
 from backbone import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
+from config import config as cfg
+from data import get_all_anchors, get_all_anchors_fpn
 from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
 from model_cascade import CascadeRCNNHead
 from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
-from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets
+from model_frcnn import (
+    BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets)
 from model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
 from model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
-from data import get_all_anchors, get_all_anchors_fpn
-from config import config as cfg
 class GeneralizedRCNN(ModelDesc):

@@ -2,8 +2,8 @@
 # File: model_box.py
 import numpy as np
-from collections import namedtuple
 import tensorflow as tf
+from collections import namedtuple
 from tensorpack.tfutils.scope_utils import under_name_scope

@@ -3,28 +3,20 @@
 # File: train.py
 import argparse
-import itertools
-import numpy as np
-import os
-import shutil
-import cv2
 import six
-assert six.PY3, "FasterRCNN requires Python 3!"
+assert six.PY3, "This example requires Python 3!"
-import tensorflow as tf
-import tqdm
-import tensorpack.utils.viz as tpviz
 from tensorpack import *
 from tensorpack.tfutils import collect_env_info
 from tensorpack.tfutils.common import get_tf_version_tuple
-from generalized_rcnn import ResNetFPNModel, ResNetC4Model
-from dataset import DatasetRegistry
 from coco import register_coco
-from config import finalize_configs, config as cfg
+from config import config as cfg
-from data import get_eval_dataflow, get_train_dataflow
+from config import finalize_configs
-from eval import DetectionResult, predict_image, multithread_predict_dataflow, EvalCallback
+from data import get_train_dataflow
-from viz import draw_annotation, draw_final_outputs, draw_predictions, draw_proposal_recall
+from eval import EvalCallback
+from generalized_rcnn import ResNetC4Model, ResNetFPNModel
 try:
     import horovod.tensorflow as hvd

@@ -32,94 +24,11 @@ except ImportError:
     pass
-def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
-    """
-    Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
-    """
-    df = get_train_dataflow()   # we don't visualize mask stuff
-    df.reset_state()
-    pred = OfflinePredictor(PredictConfig(
-        model=model,
-        session_init=get_model_loader(model_path),
-        input_names=['image', 'gt_boxes', 'gt_labels'],
-        output_names=[
-            'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
-            'generate_{}_proposals/scores'.format('fpn' if cfg.MODE_FPN else 'rpn'),
-            'fastrcnn_all_scores',
-            'output/boxes',
-            'output/scores',
-            'output/labels',
-        ]))
-    if os.path.isdir(output_dir):
-        shutil.rmtree(output_dir)
-    utils.fs.mkdir_p(output_dir)
-    with tqdm.tqdm(total=nr_visualize) as pbar:
-        for idx, dp in itertools.islice(enumerate(df), nr_visualize):
-            img, gt_boxes, gt_labels = dp['image'], dp['gt_boxes'], dp['gt_labels']
-            rpn_boxes, rpn_scores, all_scores, \
-                final_boxes, final_scores, final_labels = pred(img, gt_boxes, gt_labels)
-            # draw groundtruth boxes
-            gt_viz = draw_annotation(img, gt_boxes, gt_labels)
-            # draw best proposals for each groundtruth, to show recall
-            proposal_viz, good_proposals_ind = draw_proposal_recall(img, rpn_boxes, rpn_scores, gt_boxes)
-            # draw the scores for the above proposals
-            score_viz = draw_predictions(img, rpn_boxes[good_proposals_ind], all_scores[good_proposals_ind])
-            results = [DetectionResult(*args) for args in
-                       zip(final_boxes, final_scores, final_labels,
-                           [None] * len(final_labels))]
-            final_viz = draw_final_outputs(img, results)
-            viz = tpviz.stack_patches([
-                gt_viz, proposal_viz,
-                score_viz, final_viz], 2, 2)
-            if os.environ.get('DISPLAY', None):
-                tpviz.interactive_imshow(viz)
-            cv2.imwrite("{}/{:03d}.png".format(output_dir, idx), viz)
-            pbar.update()
-
-
-def do_evaluate(pred_config, output_file):
-    num_gpu = cfg.TRAIN.NUM_GPUS
-    graph_funcs = MultiTowerOfflinePredictor(
-        pred_config, list(range(num_gpu))).get_predictors()
-    for dataset in cfg.DATA.VAL:
-        logger.info("Evaluating {} ...".format(dataset))
-        dataflows = [
-            get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
-            for k in range(num_gpu)]
-        all_results = multithread_predict_dataflow(dataflows, graph_funcs)
-        output = output_file + '-' + dataset
-        DatasetRegistry.get(dataset).eval_inference_results(all_results, output)
-
-
-def do_predict(pred_func, input_file):
-    img = cv2.imread(input_file, cv2.IMREAD_COLOR)
-    results = predict_image(img, pred_func)
-    final = draw_final_outputs(img, results)
-    viz = np.concatenate((img, final), axis=1)
-    cv2.imwrite("output.png", viz)
-    logger.info("Inference output for {} written to output.png".format(input_file))
-    tpviz.interactive_imshow(viz)
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--load', help='load a model for evaluation or training. Can overwrite BACKBONE.WEIGHTS')
+    parser.add_argument('--load', help='load a model to start training from. Can overwrite BACKBONE.WEIGHTS')
     parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn')
-    parser.add_argument('--visualize', action='store_true', help='visualize intermediate results')
+    parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py", nargs='+')
-    parser.add_argument('--evaluate', help="Run evaluation. "
-                        "This argument is the path to the output json evaluation file")
-    parser.add_argument('--predict', help="Run prediction on a given image. "
-                        "This argument is the path to the input image file", nargs='+')
-    parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
-                        nargs='+')
     if get_tf_version_tuple() < (1, 6):
         # https://github.com/tensorflow/tensorflow/issues/14657

@@ -130,35 +39,7 @@ if __name__ == '__main__':
         cfg.update_args(args.config)
     register_coco(cfg.DATA.BASEDIR)  # add COCO datasets to the registry
-    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
+    # Setup logger ...
-    if args.visualize or args.evaluate or args.predict:
-        if not tf.test.is_gpu_available():
-            from tensorflow.python.framework import test_util
-            assert get_tf_version_tuple() >= (1, 7) and test_util.IsMklEnabled(), \
-                "Inference requires either GPU support or MKL support!"
-        assert args.load
-        finalize_configs(is_training=False)
-        if args.predict or args.visualize:
-            cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS
-        if args.visualize:
-            do_visualize(MODEL, args.load)
-        else:
-            predcfg = PredictConfig(
-                model=MODEL,
-                session_init=get_model_loader(args.load),
-                input_names=MODEL.get_inference_tensor_names()[0],
-                output_names=MODEL.get_inference_tensor_names()[1])
-            if args.predict:
-                predictor = OfflinePredictor(predcfg)
-                for image_file in args.predict:
-                    do_predict(predictor, image_file)
-            elif args.evaluate:
-                assert args.evaluate.endswith('.json'), args.evaluate
-                do_evaluate(predcfg, args.evaluate)
-    else:
-        is_horovod = cfg.TRAINER == 'horovod'
+    is_horovod = cfg.TRAINER == 'horovod'
     if is_horovod:
         hvd.init()

@@ -169,8 +50,9 @@ if __name__ == '__main__':
     logger.info("Environment Information:\n" + collect_env_info())
     finalize_configs(is_training=True)
-    stepnum = cfg.TRAIN.STEPS_PER_EPOCH
+    # Compute the training schedule from the number of GPUs ...
+    stepnum = cfg.TRAIN.STEPS_PER_EPOCH
     # warmup is step based, lr is epoch based
     init_lr = cfg.TRAIN.WARMUP_INIT_LR * min(8. / cfg.TRAIN.NUM_GPUS, 1.)
     warmup_schedule = [(0, init_lr), (cfg.TRAIN.WARMUP, cfg.TRAIN.BASE_LR)]

@@ -189,6 +71,9 @@ if __name__ == '__main__':
     total_passes = cfg.TRAIN.LR_SCHEDULE[-1] * 8 / train_dataflow.size()
     logger.info("Total passes of the training set is: {:.5g}".format(total_passes))
+    # Create model and callbacks ...
+    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
     callbacks = [
         PeriodicCallback(
             ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),

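The new predict.py added by this commit is not shown in this view. For the `--evaluate` path, the `do_evaluate` helper removed above already contains the multi-GPU evaluation flow; a self-contained sketch of that flow under the same assumptions (placeholder checkpoint path, COCO splits registered via `register_coco`):

```python
from tensorpack import MultiTowerOfflinePredictor, PredictConfig, get_model_loader
from tensorpack.utils import logger

from coco import register_coco
from config import config as cfg, finalize_configs
from data import get_eval_dataflow
from dataset import DatasetRegistry
from eval import multithread_predict_dataflow
from generalized_rcnn import ResNetC4Model, ResNetFPNModel

register_coco(cfg.DATA.BASEDIR)
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
finalize_configs(is_training=False)

predcfg = PredictConfig(
    model=MODEL,
    session_init=get_model_loader('/path/to/Trained-Model-Checkpoint'),
    input_names=MODEL.get_inference_tensor_names()[0],
    output_names=MODEL.get_inference_tensor_names()[1])

# One predictor per GPU, each consuming a shard of the validation set,
# mirroring the removed do_evaluate().
num_gpu = cfg.TRAIN.NUM_GPUS
graph_funcs = MultiTowerOfflinePredictor(predcfg, list(range(num_gpu))).get_predictors()
for dataset in cfg.DATA.VAL:
    logger.info("Evaluating {} ...".format(dataset))
    dataflows = [get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
                 for k in range(num_gpu)]
    all_results = multithread_predict_dataflow(dataflows, graph_funcs)
    # Write COCO-format results, one output file per validation split.
    DatasetRegistry.get(dataset).eval_inference_results(all_results, 'output.json-' + dataset)
```
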
@@ -8,8 +8,8 @@ from tensorpack.utils import viz
 from tensorpack.utils.palette import PALETTE_RGB
 from config import config as cfg
-from utils.np_box_ops import iou as np_iou
 from utils.np_box_ops import area as np_area
+from utils.np_box_ops import iou as np_iou
 def draw_annotation(img, boxes, klass, is_crowd=None):