Commit 754e17fc authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] some renames to avoid the name of "COCO"

parent cc63dee7
### File Structure ### File Structure
This is a minimal implementation that simply contains these files: This is a minimal implementation that simply contains these files:
+ coco.py: load COCO data + dataset.py: load and evaluate COCO dataset
+ data.py: prepare data for training + data.py: prepare data for training & inference
+ common.py: common data preparation utilities + common.py: common data preparation utilities
+ basemodel.py: implement backbones + basemodel.py: implement backbones
+ model_box.py: implement box-related symbolic functions + model_box.py: implement box-related symbolic functions
+ model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast-/Mask-/Cascade-RCNN models. + model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast-/Mask-/Cascade-RCNN models.
+ train.py: main training script + train.py: main entry script
+ utils/: third-party helper functions + utils/: third-party helper functions
+ eval.py: evaluation utilities + eval.py: evaluation utilities
+ viz.py: visualization utilities + viz.py: visualization utilities
...@@ -16,9 +16,10 @@ This is a minimal implementation that simply contains these files: ...@@ -16,9 +16,10 @@ This is a minimal implementation that simply contains these files:
Data: Data:
1. It's easy to train on your own data. Just replace `COCODetection.load_many` in `data.py` by your own loader. 1. It's easy to train on your own data.
Also remember to change `DATA.NUM_CATEGORY` and `DATA.CLASS_NAMES` in the config. If your data is not in COCO format, you can just rewrite all the methods of
The current evaluation code is also COCO-specific, and you may need to change it to use your data and metrics. `DetectionDataset` following its documentation in `dataset.py`.
You'll implement the logic to load your dataset and evaluate predictions.
2. You can easily add more augmentations such as rotation, but be careful how a box should be 2. You can easily add more augmentations such as rotation, but be careful how a box should be
augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners, augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners,
......
...@@ -80,14 +80,14 @@ _C.MODE_MASK = True # FasterRCNN or MaskRCNN ...@@ -80,14 +80,14 @@ _C.MODE_MASK = True # FasterRCNN or MaskRCNN
_C.MODE_FPN = False _C.MODE_FPN = False
# dataset ----------------------- # dataset -----------------------
_C.DATA.BASEDIR = '/path/to/your/COCO/DIR' _C.DATA.BASEDIR = '/path/to/your/DATA/DIR'
# All TRAIN dataset will be concatenated for training. # All TRAIN dataset will be concatenated for training.
_C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e. trainval35k, AKA train2017 _C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e. trainval35k, AKA train2017
# Each VAL dataset will be evaluated separately (instead of concatenated) # Each VAL dataset will be evaluated separately (instead of concatenated)
_C.DATA.VAL = ('minival2014', ) # AKA val2017 _C.DATA.VAL = ('minival2014', ) # AKA val2017
# These two configs will be populated later by the dataset loader:
_C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO) _C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO)
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG". _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# For COCO, this list will be populated later by the COCO data loader.
# basemodel ---------------------- # basemodel ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz _C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import copy import copy
import numpy as np import numpy as np
import cv2 import cv2
from tabulate import tabulate
from termcolor import colored
from tensorpack.dataflow import ( from tensorpack.dataflow import (
DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug) DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
...@@ -13,7 +15,7 @@ from tensorpack.utils.argtools import log_once, memoized ...@@ -13,7 +15,7 @@ from tensorpack.utils.argtools import log_once, memoized
from common import ( from common import (
CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, point8_to_box, segmentation_to_mask) CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, point8_to_box, segmentation_to_mask)
from config import config as cfg from config import config as cfg
from coco import DetectionDataset from dataset import DetectionDataset
from utils.generate_anchors import generate_anchors from utils.generate_anchors import generate_anchors
from utils.np_box_ops import area as np_area from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa from utils.np_box_ops import ioa as np_ioa
...@@ -46,6 +48,28 @@ class MalformedData(BaseException): ...@@ -46,6 +48,28 @@ class MalformedData(BaseException):
pass pass
def print_class_histogram(roidbs):
    """
    Log a table with the number of ground-truth boxes per class.

    Boxes whose class id is 0 or that are flagged as crowd are not counted.
    A final 'total' row holds the sum over all classes.

    Args:
        roidbs (list[dict]): the same format as the output of `load_training_roidbs`.
            Each dict is read through its 'class' and 'is_crowd' entries, which
            are compared element-wise and indexed together (presumably numpy
            arrays of equal length -- confirm against the dataset loader).
    """
    dataset = DetectionDataset()
    # One bin per class id: the bin edges are 0, 1, ..., num_classes.
    hist_bins = np.arange(dataset.num_classes + 1)

    # Histogram of ground-truth objects.
    # NOTE: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `int` selects the same dtype on every NumPy version.
    gt_hist = np.zeros((dataset.num_classes,), dtype=int)
    for entry in roidbs:
        # filter crowd?
        gt_inds = np.where(
            (entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
        gt_classes = entry['class'][gt_inds]
        gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
    data = [[dataset.class_names[i], v] for i, v in enumerate(gt_hist)]
    data.append(['total', sum(x[1] for x in data)])
    table = tabulate(data, headers=['class', '#box'], tablefmt='pipe')
    logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
@memoized @memoized
def get_all_anchors(stride=None, sizes=None): def get_all_anchors(stride=None, sizes=None):
""" """
...@@ -281,6 +305,7 @@ def get_train_dataflow(): ...@@ -281,6 +305,7 @@ def get_train_dataflow():
""" """
roidbs = DetectionDataset().load_training_roidbs(cfg.DATA.TRAIN) roidbs = DetectionDataset().load_training_roidbs(cfg.DATA.TRAIN)
print_class_histogram(roidbs)
# Valid training images should have at least one fg box. # Valid training images should have at least one fg box.
# But this filter shall not be applied for testing. # But this filter shall not be applied for testing.
......
...@@ -5,8 +5,6 @@ import numpy as np ...@@ -5,8 +5,6 @@ import numpy as np
import os import os
import tqdm import tqdm
import json import json
from tabulate import tabulate
from termcolor import colored
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once from tensorpack.utils.argtools import log_once
...@@ -29,6 +27,9 @@ class COCODetection(object): ...@@ -29,6 +27,9 @@ class COCODetection(object):
Mapping from the non-contiguous COCO category id to an id in [1, #category] Mapping from the non-contiguous COCO category id to an id in [1, #category]
""" """
class_names = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] # noqa
def __init__(self, basedir, name): def __init__(self, basedir, name):
self.name = name self.name = name
self._imgdir = os.path.realpath(os.path.join( self._imgdir = os.path.realpath(os.path.join(
...@@ -182,15 +183,9 @@ class COCODetection(object): ...@@ -182,15 +183,9 @@ class COCODetection(object):
class DetectionDataset(object): class DetectionDataset(object):
""" """
A singleton to load datasets, evaluate results, and provide metadata. A singleton to load datasets, evaluate results, and provide metadata.
"""
_instance = None
def __new__(cls):
if not isinstance(cls._instance, cls):
cls._instance = object.__new__(cls)
return cls._instance
To use your own dataset that's not in COCO format, rewrite all methods of this class.
"""
def __init__(self): def __init__(self):
""" """
This function is responsible for setting the dataset-specific This function is responsible for setting the dataset-specific
...@@ -198,8 +193,7 @@ class DetectionDataset(object): ...@@ -198,8 +193,7 @@ class DetectionDataset(object):
""" """
self.num_category = cfg.DATA.NUM_CATEGORY = 80 self.num_category = cfg.DATA.NUM_CATEGORY = 80
self.num_classes = self.num_category + 1 self.num_classes = self.num_category + 1
self.class_names = cfg.DATA.CLASS_NAMES = [ self.class_names = cfg.DATA.CLASS_NAMES = ["BG"] + COCODetection.class_names
"BG", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] # noqa
assert len(self.class_names) == self.num_classes assert len(self.class_names) == self.num_classes
def load_training_roidbs(self, names): def load_training_roidbs(self, names):
...@@ -284,29 +278,16 @@ class DetectionDataset(object): ...@@ -284,29 +278,16 @@ class DetectionDataset(object):
else: else:
return {} return {}
def print_class_histogram(self, roidbs): # code for singleton:
""" _instance = None
Args:
roidbs (list[dict]): the same format as the output of `load_training_roidbs`. def __new__(cls):
""" if not isinstance(cls._instance, cls):
hist_bins = np.arange(self.num_classes + 1) cls._instance = object.__new__(cls)
return cls._instance
# Histogram of ground-truth objects
gt_hist = np.zeros((self.num_classes,), dtype=np.int)
for entry in roidbs:
# filter crowd?
gt_inds = np.where(
(entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['class'][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = [[self.class_names[i], v] for i, v in enumerate(gt_hist)]
data.append(['total', sum([x[1] for x in data])])
table = tabulate(data, headers=['class', '#box'], tablefmt='pipe')
logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
if __name__ == '__main__': if __name__ == '__main__':
c = COCODetection(cfg.DATA.BASEDIR, 'train2014') c = COCODetection(cfg.DATA.BASEDIR, 'train2014')
gt_boxes = c.load(add_gt=True, add_mask=True) gt_boxes = c.load(add_gt=True, add_mask=True)
print("#Images:", len(gt_boxes)) print("#Images:", len(gt_boxes))
DetectionDataset().print_class_histogram(gt_boxes)
...@@ -54,15 +54,15 @@ def paste_mask(box, mask, shape): ...@@ -54,15 +54,15 @@ def paste_mask(box, mask, shape):
return ret return ret
def detect_one_image(img, model_func): def predict_image(img, model_func):
""" """
Run detection on one image, using the TF callable. Run detection on one image, using the TF callable.
This function should handle the preprocessing internally. This function should handle the preprocessing internally.
Args: Args:
img: an image img: an image
model_func: a callable from TF model, model_func: a callable from the TF model.
takes image and returns (boxes, probs, labels, [masks]) It takes image and returns (boxes, probs, labels, [masks])
Returns: Returns:
[DetectionResult] [DetectionResult]
...@@ -90,11 +90,12 @@ def detect_one_image(img, model_func): ...@@ -90,11 +90,12 @@ def detect_one_image(img, model_func):
return results return results
def eval_coco(df, detect_func, tqdm_bar=None): def predict_dataflow(df, model_func, tqdm_bar=None):
""" """
Args: Args:
df: a DataFlow which produces (image, image_id) df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns [DetectionResult] model_func: a callable from the TF model.
It takes image and returns (boxes, probs, labels, [masks])
tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None, tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None,
will create a new one. will create a new one.
...@@ -110,7 +111,7 @@ def eval_coco(df, detect_func, tqdm_bar=None): ...@@ -110,7 +111,7 @@ def eval_coco(df, detect_func, tqdm_bar=None):
tqdm_bar = stack.enter_context( tqdm_bar = stack.enter_context(
tqdm.tqdm(total=df.size(), **get_tqdm_kwargs())) tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()))
for img, img_id in df: for img, img_id in df:
results = detect_func(img) results = predict_image(img, model_func)
for r in results: for r in results:
res = { res = {
'image_id': img_id, 'image_id': img_id,
...@@ -130,24 +131,24 @@ def eval_coco(df, detect_func, tqdm_bar=None): ...@@ -130,24 +131,24 @@ def eval_coco(df, detect_func, tqdm_bar=None):
return all_results return all_results
def multithread_eval_coco(dataflows, detect_funcs): def multithread_predict_dataflow(dataflows, model_funcs):
""" """
Running multiple `eval_coco` in multiple threads, and aggregate the results. Running multiple `predict_dataflow` in multiple threads, and aggregate the results.
Args: Args:
dataflows: a list of DataFlow to be used in :func:`eval_coco` dataflows: a list of DataFlow to be used in :func:`predict_dataflow`
detect_funcs: a list of callable to be used in :func:`eval_coco` model_funcs: a list of callable to be used in :func:`predict_dataflow`
Returns: Returns:
list of dict, in the format used by list of dict, in the format used by
`DetectionDataset.eval_or_save_inference_results` `DetectionDataset.eval_or_save_inference_results`
""" """
num_worker = len(dataflows) num_worker = len(dataflows)
assert len(dataflows) == len(detect_funcs) assert len(dataflows) == len(model_funcs)
with ThreadPoolExecutor(max_workers=num_worker, thread_name_prefix='EvalWorker') as executor, \ with ThreadPoolExecutor(max_workers=num_worker, thread_name_prefix='EvalWorker') as executor, \
tqdm.tqdm(total=sum([df.size() for df in dataflows])) as pbar: tqdm.tqdm(total=sum([df.size() for df in dataflows])) as pbar:
futures = [] futures = []
for dataflow, pred in zip(dataflows, detect_funcs): for dataflow, pred in zip(dataflows, model_funcs):
futures.append(executor.submit(eval_coco, dataflow, pred, pbar)) futures.append(executor.submit(predict_dataflow, dataflow, pred, pbar))
all_results = list(itertools.chain(*[fut.result() for fut in futures])) all_results = list(itertools.chain(*[fut.result() for fut in futures]))
return all_results return all_results
...@@ -22,11 +22,10 @@ from tensorpack.tfutils.summary import add_moving_summary ...@@ -22,11 +22,10 @@ from tensorpack.tfutils.summary import add_moving_summary
import model_frcnn import model_frcnn
import model_mrcnn import model_mrcnn
from basemodel import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone from basemodel import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
from coco import DetectionDataset from dataset import DetectionDataset
from config import config as cfg from config import finalize_configs, config as cfg
from config import finalize_configs
from data import get_all_anchors, get_all_anchors_fpn, get_eval_dataflow, get_train_dataflow from data import get_all_anchors, get_all_anchors_fpn, get_eval_dataflow, get_train_dataflow
from eval import DetectionResult, detect_one_image, eval_coco, multithread_eval_coco from eval import DetectionResult, predict_image, predict_dataflow, multithread_predict_dataflow
from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
from model_cascade import CascadeRCNNHead from model_cascade import CascadeRCNNHead
from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
...@@ -323,7 +322,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -323,7 +322,7 @@ class ResNetFPNModel(DetectionModel):
return [] return []
def visualize(model, model_path, nr_visualize=100, output_dir='output'): def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
""" """
Visualize some intermediate results (proposals, raw predictions) inside the pipeline. Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
""" """
...@@ -375,31 +374,27 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'): ...@@ -375,31 +374,27 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
pbar.update() pbar.update()
def offline_evaluate(pred_config, output_file): def do_evaluate(pred_config, output_file):
num_gpu = cfg.TRAIN.NUM_GPUS num_gpu = cfg.TRAIN.NUM_GPUS
graph_funcs = MultiTowerOfflinePredictor( graph_funcs = MultiTowerOfflinePredictor(
pred_config, list(range(num_gpu))).get_predictors() pred_config, list(range(num_gpu))).get_predictors()
predictors = []
for k in range(num_gpu):
predictors.append(lambda img,
pred=graph_funcs[k]: detect_one_image(img, pred))
for dataset in cfg.DATA.VAL: for dataset in cfg.DATA.VAL:
logger.info("Evaluating {} ...".format(dataset)) logger.info("Evaluating {} ...".format(dataset))
dataflows = [ dataflows = [
get_eval_dataflow(dataset, shard=k, num_shards=num_gpu) get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
for k in range(num_gpu)] for k in range(num_gpu)]
if num_gpu > 1: if num_gpu > 1:
all_results = multithread_eval_coco(dataflows, predictors) all_results = multithread_predict_dataflow(dataflows, graph_funcs)
else: else:
all_results = eval_coco(dataflows[0], predictors[0]) all_results = predict_dataflow(dataflows[0], graph_funcs[0])
output = output_file + '-' + dataset output = output_file + '-' + dataset
DetectionDataset().eval_or_save_inference_results(all_results, dataset, output) DetectionDataset().eval_or_save_inference_results(all_results, dataset, output)
def predict(pred_func, input_file): def do_predict(pred_func, input_file):
img = cv2.imread(input_file, cv2.IMREAD_COLOR) img = cv2.imread(input_file, cv2.IMREAD_COLOR)
results = detect_one_image(img, pred_func) results = predict_image(img, pred_func)
final = draw_final_outputs(img, results) final = draw_final_outputs(img, results)
viz = np.concatenate((img, final), axis=1) viz = np.concatenate((img, final), axis=1)
cv2.imwrite("output.png", viz) cv2.imwrite("output.png", viz)
...@@ -427,7 +422,7 @@ class EvalCallback(Callback): ...@@ -427,7 +422,7 @@ class EvalCallback(Callback):
# Use two predictor threads per GPU to get better throughput # Use two predictor threads per GPU to get better throughput
self.num_predictor = num_gpu if buggy_tf else num_gpu * 2 self.num_predictor = num_gpu if buggy_tf else num_gpu * 2
self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)] self.predictors = [self._build_predictor(k % num_gpu) for k in range(self.num_predictor)]
self.dataflows = [get_eval_dataflow(self._eval_dataset, self.dataflows = [get_eval_dataflow(self._eval_dataset,
shard=k, num_shards=self.num_predictor) shard=k, num_shards=self.num_predictor)
for k in range(self.num_predictor)] for k in range(self.num_predictor)]
...@@ -436,15 +431,14 @@ class EvalCallback(Callback): ...@@ -436,15 +431,14 @@ class EvalCallback(Callback):
# Alternatively, can eval on all ranks and use allgather, but allgather sometimes hangs # Alternatively, can eval on all ranks and use allgather, but allgather sometimes hangs
self._horovod_run_eval = hvd.rank() == hvd.local_rank() self._horovod_run_eval = hvd.rank() == hvd.local_rank()
if self._horovod_run_eval: if self._horovod_run_eval:
self.predictor = self._build_coco_predictor(0) self.predictor = self._build_predictor(0)
self.dataflow = get_eval_dataflow(self._eval_dataset, self.dataflow = get_eval_dataflow(self._eval_dataset,
shard=hvd.local_rank(), num_shards=hvd.local_size()) shard=hvd.local_rank(), num_shards=hvd.local_size())
self.barrier = hvd.allreduce(tf.random_normal(shape=[1])) self.barrier = hvd.allreduce(tf.random_normal(shape=[1]))
def _build_coco_predictor(self, idx): def _build_predictor(self, idx):
graph_func = self.trainer.get_predictor(self._in_names, self._out_names, device=idx) return self.trainer.get_predictor(self._in_names, self._out_names, device=idx)
return lambda img: detect_one_image(img, graph_func)
def _before_train(self): def _before_train(self):
eval_period = cfg.TRAIN.EVAL_PERIOD eval_period = cfg.TRAIN.EVAL_PERIOD
...@@ -459,14 +453,14 @@ class EvalCallback(Callback): ...@@ -459,14 +453,14 @@ class EvalCallback(Callback):
def _eval(self): def _eval(self):
logdir = args.logdir logdir = args.logdir
if cfg.TRAINER == 'replicated': if cfg.TRAINER == 'replicated':
all_results = multithread_eval_coco(self.dataflows, self.predictors) all_results = multithread_predict_dataflow(self.dataflows, self.predictors)
else: else:
filenames = [os.path.join( filenames = [os.path.join(
logdir, 'outputs{}-part{}.json'.format(self.global_step, rank) logdir, 'outputs{}-part{}.json'.format(self.global_step, rank)
) for rank in range(hvd.local_size())] ) for rank in range(hvd.local_size())]
if self._horovod_run_eval: if self._horovod_run_eval:
local_results = eval_coco(self.dataflow, self.predictor) local_results = predict_dataflow(self.dataflow, self.predictor)
fname = filenames[hvd.local_rank()] fname = filenames[hvd.local_rank()]
with open(fname, 'w') as f: with open(fname, 'w') as f:
json.dump(local_results, f) json.dump(local_results, f)
...@@ -499,7 +493,7 @@ if __name__ == '__main__': ...@@ -499,7 +493,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load a model for evaluation or training. Can overwrite BACKBONE.WEIGHTS') parser.add_argument('--load', help='load a model for evaluation or training. Can overwrite BACKBONE.WEIGHTS')
parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn') parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn')
parser.add_argument('--visualize', action='store_true', help='visualize intermediate results') parser.add_argument('--visualize', action='store_true', help='visualize intermediate results')
parser.add_argument('--evaluate', help="Run evaluation on COCO. " parser.add_argument('--evaluate', help="Run evaluation. "
"This argument is the path to the output json evaluation file") "This argument is the path to the output json evaluation file")
parser.add_argument('--predict', help="Run prediction on a given image. " parser.add_argument('--predict', help="Run prediction on a given image. "
"This argument is the path to the input image file") "This argument is the path to the input image file")
...@@ -526,7 +520,7 @@ if __name__ == '__main__': ...@@ -526,7 +520,7 @@ if __name__ == '__main__':
cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS
if args.visualize: if args.visualize:
visualize(MODEL, args.load) do_visualize(MODEL, args.load)
else: else:
predcfg = PredictConfig( predcfg = PredictConfig(
model=MODEL, model=MODEL,
...@@ -534,10 +528,10 @@ if __name__ == '__main__': ...@@ -534,10 +528,10 @@ if __name__ == '__main__':
input_names=MODEL.get_inference_tensor_names()[0], input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1]) output_names=MODEL.get_inference_tensor_names()[1])
if args.predict: if args.predict:
predict(OfflinePredictor(predcfg), args.predict) do_predict(OfflinePredictor(predcfg), args.predict)
elif args.evaluate: elif args.evaluate:
assert args.evaluate.endswith('.json'), args.evaluate assert args.evaluate.endswith('.json'), args.evaluate
offline_evaluate(predcfg, args.evaluate) do_evaluate(predcfg, args.evaluate)
else: else:
is_horovod = cfg.TRAINER == 'horovod' is_horovod = cfg.TRAINER == 'horovod'
if is_horovod: if is_horovod:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment