Commit 9b1b5f29 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] add DatasetSplit and DatasetRegistry for generic dataset handling

parent 8908e6d4
### File Structure ### File Structure
This is a minimal implementation that simply contains these files: This is a minimal implementation that simply contains these files:
+ dataset.py: load and evaluate COCO dataset + dataset.py: the dataset interface
+ coco.py: load COCO data to the dataset interface
+ data.py: prepare data for training & inference + data.py: prepare data for training & inference
+ common.py: common data preparation utilities + common.py: common data preparation utilities
+ backbone.py: implement backbones + backbone.py: implement backbones
+ model_box.py: implement box-related symbolic functions
+ generalized_rcnn.py: implement variants of generalized R-CNN architecture + generalized_rcnn.py: implement variants of generalized R-CNN architecture
+ model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models. + model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models.
+ model_box.py: implement box-related symbolic functions
+ train.py: main entry script + train.py: main entry script
+ utils/: third-party helper functions + utils/: third-party helper functions
+ eval.py: evaluation utilities + eval.py: evaluation utilities
...@@ -17,20 +18,25 @@ This is a minimal implementation that simply contains these files: ...@@ -17,20 +18,25 @@ This is a minimal implementation that simply contains these files:
Data: Data:
1. It's easy to train on your own data by changing `dataset.py`. 1. It's easy to train on your own data, by calling `DatasetRegistry.register(name, lambda: YourDatasetSplit())`,
and modify `cfg.DATA.*` accordingly.
`YourDatasetSplit` can be:
+ `COCODetection`, if your data is already in COCO format. In this case, you need to
modify `COCODetection` to change the class names and the id mapping.
+ Your own class, if your data is not in COCO format.
You need to write a subclass of `DatasetSplit`, similar to `COCODetection`.
In this class you'll implement the logic to load your dataset and evaluate predictions.
The documentation is in the docstring of `DatasetSplit`.
+ If your data is in COCO format, modify `COCODetection` 1. If you load a COCO-trained model on a different dataset, you may see error messages
to change the class names and the id mapping. complaining about unmatched number of categories for certain weights in the checkpoint.
+ If your data is not in COCO format, ignore `COCODetection` completely and You can either remove those weights in checkpoint, or rename them in the model.
rewrite all the methods of See [tensorpack tutorial](https://tensorpack.readthedocs.io/tutorial/save-load.html) for more details.
`DetectionDataset` following its documents.
You'll implement the logic to load your dataset and evaluate predictions.
+ If you load a COCO-trained model on a different dataset, you'll see error messages
complaining about unmatched number of categories for certain weights in the checkpoint.
You can either remove those weights in checkpoint, or rename them in the model.
See [tensorpack tutorial](https://tensorpack.readthedocs.io/tutorial/save-load.html) for more details.
2. You can easily add more augmentations such as rotation, but be careful how a box should be 1. You can easily add more augmentations such as rotation, but be careful how a box should be
augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners, augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners,
which is probably not the optimal way. which is probably not the optimal way.
A TODO is to generate bounding box from segmentation, so more augmentations can be naturally supported. A TODO is to generate bounding box from segmentation, so more augmentations can be naturally supported.
......
# -*- coding: utf-8 -*-
import numpy as np
import os
import tqdm
import json
from tensorpack.utils import logger
from tensorpack.utils.timer import timed_operation
from config import config as cfg
from dataset import DatasetSplit, DatasetRegistry
__all__ = ['register_coco']
class COCODetection(DatasetSplit):
    """
    A `DatasetSplit` backed by one COCO-format instance annotation file.

    It loads image records (optionally with ground-truth boxes and polygon
    masks) through pycocotools and evaluates detection results with `COCOeval`.
    """
    # handle the weird (but standard) split of train and val
    _INSTANCE_TO_BASEDIR = {
        'valminusminival2014': 'val2014',
        'minival2014': 'val2014',
    }

    # Mapping from the incontinuous COCO category id to an id in [1, #category]
    # For your own dataset, this should usually be an identity mapping.
    COCO_id_to_category_id = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}  # noqa

    # 80 names for COCO
    class_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]  # noqa
    # Executed once, when the class body is evaluated: publishes the class
    # names (with a leading background entry) to the global config.
    cfg.DATA.CLASS_NAMES = ["BG"] + class_names

    def __init__(self, basedir, name):
        """
        Args:
            basedir (str): root to the dataset
            name (str): the name of the split, e.g. "train2017"

        Raises:
            AssertionError: if the image directory or the annotation file
                for this split does not exist.
        """
        basedir = os.path.expanduser(basedir)
        self.name = name
        # Some splits (see _INSTANCE_TO_BASEDIR) keep their images in the
        # directory of another split.
        self._imgdir = os.path.realpath(os.path.join(
            basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
        assert os.path.isdir(self._imgdir), self._imgdir
        annotation_file = os.path.join(
            basedir, 'annotations/instances_{}.json'.format(name))
        assert os.path.isfile(annotation_file), annotation_file

        # imported lazily so the module can be imported without pycocotools
        from pycocotools.coco import COCO
        self.coco = COCO(annotation_file)
        logger.info("Instances loaded from {}.".format(annotation_file))

    # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
    def print_coco_metrics(self, json_file):
        """
        Run the COCO evaluation on a results file and print the summary.

        Args:
            json_file (str): path to the results json file in coco format

        Returns:
            dict: the evaluation metrics, keyed by 'mAP(bbox)/<field>' and,
                if the results contain a 'segmentation' entry, also by
                'mAP(segm)/<field>'.
        """
        from pycocotools.cocoeval import COCOeval
        ret = {}
        fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
        cocoDt = self.coco.loadRes(json_file)

        cocoEval = COCOeval(self.coco, cocoDt, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        # only the first len(fields) entries of the stats are reported
        for k, field in enumerate(fields):
            ret['mAP(bbox)/' + field] = cocoEval.stats[k]

        # Peek at the results to decide whether mask evaluation applies.
        # Use a context manager so the file handle is closed promptly.
        with open(json_file) as f:
            json_obj = json.load(f)
        if len(json_obj) > 0 and 'segmentation' in json_obj[0]:
            cocoEval = COCOeval(self.coco, cocoDt, 'segm')
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            for k, field in enumerate(fields):
                ret['mAP(segm)/' + field] = cocoEval.stats[k]
        return ret

    def load(self, add_gt=True, add_mask=False):
        """
        Args:
            add_gt: whether to add ground truth bounding box annotations to the dicts
            add_mask: whether to also add ground truth mask

        Returns:
            a list of dict, each has keys including:
                'image_id', 'file_name',
                and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
                'segmentation'.

        Raises:
            AssertionError: if add_mask is set without add_gt, or if an image
                file does not exist.
        """
        if add_mask:
            assert add_gt
        with timed_operation('Load Groundtruth Boxes for {}'.format(self.name)):
            img_ids = self.coco.getImgIds()
            img_ids.sort()
            # list of dict, each has keys: height,width,id,file_name
            # Work on shallow copies: the records are modified below, and the
            # dicts returned by loadImgs are presumably the ones stored inside
            # the COCO object, which must stay intact for repeated load()
            # calls and for evaluation.
            imgs = [dict(img) for img in self.coco.loadImgs(img_ids)]

            for img in tqdm.tqdm(imgs):
                img['image_id'] = img.pop('id')
                self._use_absolute_file_name(img)
                if add_gt:
                    self._add_detection_gt(img, add_mask)
            return imgs

    def _use_absolute_file_name(self, img):
        """
        Change relative filename to absolute file name.

        Raises:
            AssertionError: if the resulting path is not an existing file.
        """
        img['file_name'] = os.path.join(
            self._imgdir, img['file_name'])
        assert os.path.isfile(img['file_name']), img['file_name']

    def _add_detection_gt(self, img, add_mask):
        """
        Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
        If add_mask is True, also add 'segmentation' in coco poly format.

        The 'width' and 'height' keys are removed from `img`.
        The annotation dicts of the COCO object are only read, never modified.
        """
        # ann_ids = self.coco.getAnnIds(imgIds=img['image_id'])
        # objs = self.coco.loadAnns(ann_ids)
        objs = self.coco.imgToAnns[img['image_id']]  # equivalent but faster than the above two lines

        # clean-up boxes
        width = img.pop('width')
        height = img.pop('height')

        all_boxes = []
        all_cls = []
        all_iscrowd = []
        all_segm = []
        for objid, obj in enumerate(objs):
            if obj.get('ignore', 0) == 1:
                continue
            x1, y1, w, h = obj['bbox']
            # bbox is originally in float
            # x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
            # But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
            x1, y1 = float(x1), float(y1)
            # Compute the far corner BEFORE clipping the near one; otherwise a
            # box starting outside the image would keep its full width/height
            # measured from the image border.
            x2, y2 = x1 + float(w), y1 + float(h)
            x1 = np.clip(x1, 0, width)
            y1 = np.clip(y1, 0, height)
            x2 = np.clip(x2, 0, width)
            y2 = np.clip(y2, 0, height)
            w, h = x2 - x1, y2 - y1
            # Require non-zero seg area and more than 1x1 box size
            if not (obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4):
                continue

            all_boxes.append([x1, y1, x2, y2])
            all_cls.append(self.COCO_id_to_category_id[obj['category_id']])
            all_iscrowd.append(obj['iscrowd'])

            if add_mask:
                segs = obj['segmentation']
                if not isinstance(segs, list):
                    # a non-polygon segmentation is only accepted for crowd objects
                    assert obj['iscrowd'] == 1
                    all_segm.append(None)
                else:
                    # a polygon needs at least 3 points, i.e. 6 coordinates
                    valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
                    if len(valid_segs) == 0:
                        logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
                    elif len(valid_segs) < len(segs):
                        logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))
                    all_segm.append(valid_segs)

        # all geometrically-valid boxes are returned
        # reshape keeps the (n, 4) layout even when n == 0
        img['boxes'] = np.asarray(all_boxes, dtype='float32').reshape(-1, 4)  # nx4
        img['class'] = np.asarray(all_cls, dtype='int32')  # n, always >0
        img['is_crowd'] = np.asarray(all_iscrowd, dtype='int8')  # n,
        if add_mask:
            # also required to be float32
            img['segmentation'] = all_segm

    def training_roidbs(self):
        """
        Returns:
            list[dict]: the output of `load` with ground truth, and with masks
                if `cfg.MODE_MASK` is set.
        """
        return self.load(add_gt=True, add_mask=cfg.MODE_MASK)

    def inference_roidbs(self):
        """
        Returns:
            list[dict]: the output of `load` without ground truth.
        """
        return self.load(add_gt=False)

    def eval_inference_results(self, results, output=None):
        """
        Convert results to COCO's format, save them to `output` and evaluate them.

        Args:
            results (list[dict]): the inference results. Each dict is modified
                in place: 'category_id' becomes the COCO category id and
                'bbox' is converted from x1, y1, x2, y2 to x, y, w, h.
            output (str): the file to write the results to. Required for COCO.

        Returns:
            dict: the evaluation metrics, or an empty dict if `results` is empty.

        Raises:
            AssertionError: if `output` is None.
        """
        # fail before touching `results`, so a bad call leaves them unchanged
        assert output is not None, "COCO evaluation requires an output file!"
        continuous_id_to_COCO_id = {v: k for k, v in self.COCO_id_to_category_id.items()}
        for res in results:
            # convert to COCO's incontinuous category id
            res['category_id'] = continuous_id_to_COCO_id[res['category_id']]
            # COCO expects results in xywh format
            box = res['bbox']
            box[2] -= box[0]
            box[3] -= box[1]
            res['bbox'] = [round(float(x), 3) for x in box]

        with open(output, 'w') as f:
            json.dump(results, f)
        if len(results):
            # sometimes may crash if the results are empty?
            return self.print_coco_metrics(output)
        else:
            return {}
def register_coco(basedir):
    """
    Register every known COCO split under the name "coco_<split>"
    (e.g. "coco_train2017"), so that `cfg.DATA.TRAIN/VAL` can refer to them.

    Args:
        basedir (str): root of the COCO dataset, passed to `COCODetection`.
    """
    split_names = (
        "train2017", "val2017", "train2014", "val2014",
        "valminusminival2014", "minival2014",
    )
    for split_name in split_names:
        # The default argument freezes the current split name; the dataset
        # itself is only constructed when the registered function is called.
        def make_split(x=split_name):
            return COCODetection(basedir, x)

        DatasetRegistry.register("coco_" + split_name, make_split)
if __name__ == '__main__':
    # Manual smoke test: load one split with boxes and masks and report its size.
    dataset = COCODetection('~/data/coco', 'train2014')
    records = dataset.load(add_gt=True, add_mask=True)
    print("#Images:", len(records))
...@@ -85,11 +85,11 @@ _C.MODE_FPN = False ...@@ -85,11 +85,11 @@ _C.MODE_FPN = False
# dataset ----------------------- # dataset -----------------------
_C.DATA.BASEDIR = '/path/to/your/DATA/DIR' _C.DATA.BASEDIR = '/path/to/your/DATA/DIR'
# All TRAIN dataset will be concatenated for training. # All TRAIN dataset will be concatenated for training.
_C.DATA.TRAIN = ('train2014', 'valminusminival2014') # i.e. trainval35k, AKA train2017 _C.DATA.TRAIN = ('coco_train2014', 'coco_valminusminival2014') # i.e. trainval35k, AKA train2017
# Each VAL dataset will be evaluated separately (instead of concatenated) # Each VAL dataset will be evaluated separately (instead of concatenated)
_C.DATA.VAL = ('minival2014', ) # AKA val2017 _C.DATA.VAL = ('coco_minival2014', ) # AKA val2017
# This two config will be populated later by the dataset loader: # This two config will be populated later by the dataset loader:
_C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO) _C.DATA.NUM_CATEGORY = 80 # without the background class (e.g., 80 for COCO)
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG". _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1] # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
_C.DATA.ABSOLUTE_COORD = True _C.DATA.ABSOLUTE_COORD = True
...@@ -216,7 +216,6 @@ def finalize_configs(is_training): ...@@ -216,7 +216,6 @@ def finalize_configs(is_training):
Run some sanity checks, and populate some configs from others Run some sanity checks, and populate some configs from others
""" """
_C.freeze(False) # populate new keys now _C.freeze(False) # populate new keys now
_C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well
_C.DATA.VAL = (_C.DATA.VAL, ) _C.DATA.VAL = (_C.DATA.VAL, )
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import copy import copy
import numpy as np import numpy as np
import cv2 import cv2
import itertools
from tabulate import tabulate from tabulate import tabulate
from termcolor import colored from termcolor import colored
...@@ -16,7 +17,7 @@ from common import ( ...@@ -16,7 +17,7 @@ from common import (
CustomResize, DataFromListOfDict, box_to_point8, CustomResize, DataFromListOfDict, box_to_point8,
filter_boxes_inside_shape, point8_to_box, segmentation_to_mask, np_iou) filter_boxes_inside_shape, point8_to_box, segmentation_to_mask, np_iou)
from config import config as cfg from config import config as cfg
from dataset import DetectionDataset from dataset import DatasetRegistry
from utils.generate_anchors import generate_anchors from utils.generate_anchors import generate_anchors
from utils.np_box_ops import area as np_area, ioa as np_ioa from utils.np_box_ops import area as np_area, ioa as np_ioa
...@@ -30,20 +31,20 @@ class MalformedData(BaseException): ...@@ -30,20 +31,20 @@ class MalformedData(BaseException):
def print_class_histogram(roidbs): def print_class_histogram(roidbs):
""" """
Args: Args:
roidbs (list[dict]): the same format as the output of `load_training_roidbs`. roidbs (list[dict]): the same format as the output of `training_roidbs`.
""" """
dataset = DetectionDataset() # labels are in [1, NUM_CATEGORY], hence +2 for bins
hist_bins = np.arange(dataset.num_classes + 1) hist_bins = np.arange(cfg.DATA.NUM_CATEGORY + 2)
# Histogram of ground-truth objects # Histogram of ground-truth objects
gt_hist = np.zeros((dataset.num_classes,), dtype=np.int) gt_hist = np.zeros((cfg.DATA.NUM_CATEGORY + 1,), dtype=np.int)
for entry in roidbs: for entry in roidbs:
# filter crowd? # filter crowd?
gt_inds = np.where( gt_inds = np.where(
(entry['class'] > 0) & (entry['is_crowd'] == 0))[0] (entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['class'][gt_inds] gt_classes = entry['class'][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0] gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = [[dataset.class_names[i], v] for i, v in enumerate(gt_hist)] data = [[cfg.DATA.CLASS_NAMES[i], v] for i, v in enumerate(gt_hist)]
data.append(['total', sum(x[1] for x in data)]) data.append(['total', sum(x[1] for x in data)])
# the first line is BG # the first line is BG
table = tabulate(data[1:], headers=['class', '#box'], tablefmt='pipe') table = tabulate(data[1:], headers=['class', '#box'], tablefmt='pipe')
...@@ -284,7 +285,7 @@ def get_train_dataflow(): ...@@ -284,7 +285,7 @@ def get_train_dataflow():
If MODE_MASK, gt_masks: (N, h, w) If MODE_MASK, gt_masks: (N, h, w)
""" """
roidbs = DetectionDataset().load_training_roidbs(cfg.DATA.TRAIN) roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
print_class_histogram(roidbs) print_class_histogram(roidbs)
# Valid training images should have at least one fg box. # Valid training images should have at least one fg box.
...@@ -387,7 +388,8 @@ def get_eval_dataflow(name, shard=0, num_shards=1): ...@@ -387,7 +388,8 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
name (str): name of the dataset to evaluate name (str): name of the dataset to evaluate
shard, num_shards: to get subset of evaluation data shard, num_shards: to get subset of evaluation data
""" """
roidbs = DetectionDataset().load_inference_roidbs(name) roidbs = DatasetRegistry.get(name).inference_roidbs()
logger.info("Found {} images for inference.".format(len(roidbs)))
num_imgs = len(roidbs) num_imgs = len(roidbs)
img_per_shard = num_imgs // num_shards img_per_shard = num_imgs // num_shards
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: coco.py __all__ = ['DatasetRegistry', 'DatasetSplit']
import numpy as np
import os
import tqdm
import json
from tensorpack.utils import logger class DatasetSplit():
from tensorpack.utils.timer import timed_operation
from config import config as cfg
__all__ = ['COCODetection', 'DetectionDataset']
class COCODetection:
# handle the weird (but standard) split of train and val
_INSTANCE_TO_BASEDIR = {
'valminusminival2014': 'val2014',
'minival2014': 'val2014',
}
COCO_id_to_category_id = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80} # noqa
"""
Mapping from the incontinuous COCO category id to an id in [1, #category]
For your own dataset, this should usually be an identity mapping.
""" """
A class to load datasets and evaluate results for a dataset split (e.g., "coco_train2017")
# 80 names for COCO To use your own dataset that's not in COCO format, write a subclass that
class_names = [ implements the interfaces.
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] # noqa
def __init__(self, basedir, name):
basedir = os.path.expanduser(basedir)
self.name = name
self._imgdir = os.path.realpath(os.path.join(
basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
assert os.path.isdir(self._imgdir), self._imgdir
annotation_file = os.path.join(
basedir, 'annotations/instances_{}.json'.format(name))
assert os.path.isfile(annotation_file), annotation_file
from pycocotools.coco import COCO
self.coco = COCO(annotation_file)
logger.info("Instances loaded from {}.".format(annotation_file))
# https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
def print_coco_metrics(self, json_file):
"""
Args:
json_file (str): path to the results json file in coco format
Returns:
dict: the evaluation metrics
"""
from pycocotools.cocoeval import COCOeval
ret = {}
cocoDt = self.coco.loadRes(json_file)
cocoEval = COCOeval(self.coco, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
for k in range(6):
ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]
json_obj = json.load(open(json_file))
if len(json_obj) > 0 and 'segmentation' in json_obj[0]:
cocoEval = COCOeval(self.coco, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
for k in range(6):
ret['mAP(segm)/' + fields[k]] = cocoEval.stats[k]
return ret
def load(self, add_gt=True, add_mask=False):
"""
Args:
add_gt: whether to add ground truth bounding box annotations to the dicts
add_mask: whether to also add ground truth mask
Returns:
a list of dict, each has keys including:
'image_id', 'file_name',
and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
'segmentation'.
"""
if add_mask:
assert add_gt
with timed_operation('Load Groundtruth Boxes for {}'.format(self.name)):
img_ids = self.coco.getImgIds()
img_ids.sort()
# list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids)
for img in tqdm.tqdm(imgs):
img['image_id'] = img.pop('id')
self._use_absolute_file_name(img)
if add_gt:
self._add_detection_gt(img, add_mask)
return imgs
def _use_absolute_file_name(self, img):
"""
Change relative filename to abosolute file name.
"""
img['file_name'] = os.path.join(
self._imgdir, img['file_name'])
assert os.path.isfile(img['file_name']), img['file_name']
def _add_detection_gt(self, img, add_mask):
"""
Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
If add_mask is True, also add 'segmentation' in coco poly format.
"""
# ann_ids = self.coco.getAnnIds(imgIds=img['image_id'])
# objs = self.coco.loadAnns(ann_ids)
objs = self.coco.imgToAnns[img['image_id']] # equivalent but faster than the above two lines
# clean-up boxes
valid_objs = []
width = img.pop('width')
height = img.pop('height')
for objid, obj in enumerate(objs):
if obj.get('ignore', 0) == 1:
continue
x1, y1, w, h = obj['bbox']
# bbox is originally in float
# x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
# But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
x1 = np.clip(float(x1), 0, width)
y1 = np.clip(float(y1), 0, height)
w = np.clip(float(x1 + w), 0, width) - x1
h = np.clip(float(y1 + h), 0, height) - y1
# Require non-zero seg area and more than 1x1 box size
if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4:
obj['bbox'] = [x1, y1, x1 + w, y1 + h]
valid_objs.append(obj)
if add_mask:
segs = obj['segmentation']
if not isinstance(segs, list):
assert obj['iscrowd'] == 1
obj['segmentation'] = None
else:
valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
if len(valid_segs) == 0:
logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
elif len(valid_segs) < len(segs):
logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))
obj['segmentation'] = valid_segs
# all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4)
cls = np.asarray([
self.COCO_id_to_category_id[obj['category_id']]
for obj in valid_objs], dtype='int32') # (n,)
is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')
# add the keys
img['boxes'] = boxes # nx4
img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n,
if add_mask:
# also required to be float32
img['segmentation'] = [
obj['segmentation'] for obj in valid_objs]
@staticmethod
def load_many(basedir, names, add_gt=True, add_mask=False):
"""
Load and merges several instance files together.
Returns the same format as :meth:`COCODetection.load`.
"""
if not isinstance(names, (list, tuple)):
names = [names]
ret = []
for n in names:
coco = COCODetection(basedir, n)
ret.extend(coco.load(add_gt, add_mask=add_mask))
return ret
class DetectionDataset:
"""
A singleton to load datasets, evaluate results, and provide metadata.
To use your own dataset that's not in COCO format, rewrite all methods of this class.
""" """
def __init__(self): def training_roidbs(self):
"""
This function is responsible for setting the dataset-specific
attributes in both cfg and self.
"""
self.num_category = cfg.DATA.NUM_CATEGORY = len(COCODetection.class_names)
self.num_classes = self.num_category + 1
self.class_names = cfg.DATA.CLASS_NAMES = ["BG"] + COCODetection.class_names
def load_training_roidbs(self, names):
""" """
Args:
names (list[str]): name of the training datasets, e.g. ['train2014', 'valminusminival2014']
Returns: Returns:
roidbs (list[dict]): roidbs (list[dict]):
...@@ -225,14 +31,10 @@ class DetectionDataset: ...@@ -225,14 +31,10 @@ class DetectionDataset:
Include this field only if training Mask R-CNN. Include this field only if training Mask R-CNN.
""" """
return COCODetection.load_many( raise NotImplementedError()
cfg.DATA.BASEDIR, names, add_gt=True, add_mask=cfg.MODE_MASK)
def load_inference_roidbs(self, name): def inference_roidbs(self):
""" """
Args:
name (str): name of one inference dataset, e.g. 'minival2014'
Returns: Returns:
roidbs (list[dict]): roidbs (list[dict]):
...@@ -242,56 +44,48 @@ class DetectionDataset: ...@@ -242,56 +44,48 @@ class DetectionDataset:
file_name (str): full path to the image file_name (str): full path to the image
image_id (str): an id for the image. The inference results will be stored with this id. image_id (str): an id for the image. The inference results will be stored with this id.
""" """
return COCODetection.load_many(cfg.DATA.BASEDIR, name, add_gt=False) raise NotImplementedError()
def eval_or_save_inference_results(self, results, dataset, output=None): def eval_inference_results(self, results, output=None):
""" """
Args: Args:
results (list[dict]): the inference results as dicts. results (list[dict]): the inference results as dicts.
Each dict corresponds to one __instance__. It contains the following keys: Each dict corresponds to one __instance__. It contains the following keys:
image_id (str): the id that matches `load_inference_roidbs`. image_id (str): the id that matches `inference_roidbs`.
category_id (int): the category prediction, in range [1, #category] category_id (int): the category prediction, in range [1, #category]
bbox (list[float]): x1, y1, x2, y2 bbox (list[float]): x1, y1, x2, y2
score (float): score (float):
segmentation: the segmentation mask in COCO's rle format. segmentation: the segmentation mask in COCO's rle format.
output (str): the output file or directory to optionally save the results to.
dataset (str): the name of the dataset to evaluate.
output (str): the output file to optionally save the results to.
Returns: Returns:
dict: the evaluation results. dict: the evaluation results.
""" """
continuous_id_to_COCO_id = {v: k for k, v in COCODetection.COCO_id_to_category_id.items()} raise NotImplementedError()
for res in results:
# convert to COCO's incontinuous category id
res['category_id'] = continuous_id_to_COCO_id[res['category_id']]
# COCO expects results in xywh format
box = res['bbox']
box[2] -= box[0]
box[3] -= box[1]
res['bbox'] = [round(float(x), 3) for x in box]
assert output is not None, "COCO evaluation requires an output file!"
with open(output, 'w') as f:
json.dump(results, f)
if len(results):
# sometimes may crash if the results are empty?
return COCODetection(cfg.DATA.BASEDIR, dataset).print_coco_metrics(output)
else:
return {}
# code for singleton: class DatasetRegistry():
_instance = None _registry = {}
def __new__(cls): @staticmethod
if not isinstance(cls._instance, cls): def register(name, func):
cls._instance = object.__new__(cls) """
return cls._instance Args:
name (str): the name of the dataset split, e.g. "coco_train2017"
func: a function which returns an instance of `DatasetSplit`
"""
assert name not in DatasetRegistry._registry, "Dataset {} was registered already!".format(name)
DatasetRegistry._registry[name] = func
@staticmethod
def get(name):
"""
Args:
name (str): the name of the dataset split, e.g. "coco_train2017"
if __name__ == '__main__': Returns:
cfg.DATA.BASEDIR = '~/data/coco' DatasetSplit
c = COCODetection(cfg.DATA.BASEDIR, 'train2014') """
roidb = c.load(add_gt=True, add_mask=True) assert name in DatasetRegistry._registry, "Dataset {} was not egistered!".format(name)
print("#Images:", len(roidb)) return DatasetRegistry._registry[name]()
...@@ -21,7 +21,7 @@ from tensorpack.utils.utils import get_tqdm ...@@ -21,7 +21,7 @@ from tensorpack.utils.utils import get_tqdm
from common import CustomResize, clip_boxes from common import CustomResize, clip_boxes
from data import get_eval_dataflow from data import get_eval_dataflow
from dataset import DetectionDataset from dataset import DatasetRegistry
from config import config as cfg from config import config as cfg
try: try:
...@@ -116,7 +116,7 @@ def predict_dataflow(df, model_func, tqdm_bar=None): ...@@ -116,7 +116,7 @@ def predict_dataflow(df, model_func, tqdm_bar=None):
Returns: Returns:
list of dict, in the format used by list of dict, in the format used by
`DetectionDataset.eval_or_save_inference_results` `DatasetSplit.eval_inference_results`
""" """
df.reset_state() df.reset_state()
all_results = [] all_results = []
...@@ -156,7 +156,7 @@ def multithread_predict_dataflow(dataflows, model_funcs): ...@@ -156,7 +156,7 @@ def multithread_predict_dataflow(dataflows, model_funcs):
Returns: Returns:
list of dict, in the format used by list of dict, in the format used by
`DetectionDataset.eval_or_save_inference_results` `DatasetSplit.eval_inference_results`
""" """
num_worker = len(model_funcs) num_worker = len(model_funcs)
assert len(dataflows) == num_worker assert len(dataflows) == num_worker
...@@ -248,8 +248,8 @@ class EvalCallback(Callback): ...@@ -248,8 +248,8 @@ class EvalCallback(Callback):
output_file = os.path.join( output_file = os.path.join(
logdir, '{}-outputs{}.json'.format(self._eval_dataset, self.global_step)) logdir, '{}-outputs{}.json'.format(self._eval_dataset, self.global_step))
scores = DetectionDataset().eval_or_save_inference_results( scores = DatasetRegistry.get(self._eval_dataset).eval_inference_results(
all_results, self._eval_dataset, output_file) all_results, output_file)
for k, v in scores.items(): for k, v in scores.items():
self.trainer.monitors.put_scalar(self._eval_dataset + '-' + k, v) self.trainer.monitors.put_scalar(self._eval_dataset + '-' + k, v)
......
...@@ -111,7 +111,7 @@ def fastrcnn_outputs(feature, num_categories, class_agnostic_regression=False): ...@@ -111,7 +111,7 @@ def fastrcnn_outputs(feature, num_categories, class_agnostic_regression=False):
Returns: Returns:
cls_logits: N x num_class classification logits cls_logits: N x num_class classification logits
reg_logits: N x num_classx4 or Nx2x4 if class agnostic reg_logits: N x num_classx4 or Nx1x4 if class agnostic
""" """
num_classes = num_categories + 1 num_classes = num_categories + 1
classification = FullyConnected( classification = FullyConnected(
......
...@@ -19,7 +19,8 @@ from tensorpack.tfutils import collect_env_info ...@@ -19,7 +19,8 @@ from tensorpack.tfutils import collect_env_info
from tensorpack.tfutils.common import get_tf_version_tuple from tensorpack.tfutils.common import get_tf_version_tuple
from generalized_rcnn import ResNetFPNModel, ResNetC4Model from generalized_rcnn import ResNetFPNModel, ResNetC4Model
from dataset import DetectionDataset from dataset import DatasetRegistry
from coco import register_coco
from config import finalize_configs, config as cfg from config import finalize_configs, config as cfg
from data import get_eval_dataflow, get_train_dataflow from data import get_eval_dataflow, get_train_dataflow
from eval import DetectionResult, predict_image, multithread_predict_dataflow, EvalCallback from eval import DetectionResult, predict_image, multithread_predict_dataflow, EvalCallback
...@@ -95,7 +96,7 @@ def do_evaluate(pred_config, output_file): ...@@ -95,7 +96,7 @@ def do_evaluate(pred_config, output_file):
for k in range(num_gpu)] for k in range(num_gpu)]
all_results = multithread_predict_dataflow(dataflows, graph_funcs) all_results = multithread_predict_dataflow(dataflows, graph_funcs)
output = output_file + '-' + dataset output = output_file + '-' + dataset
DetectionDataset().eval_or_save_inference_results(all_results, dataset, output) DatasetRegistry.get(dataset).eval_inference_results(all_results, output)
def do_predict(pred_func, input_file): def do_predict(pred_func, input_file):
...@@ -127,9 +128,9 @@ if __name__ == '__main__': ...@@ -127,9 +128,9 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
if args.config: if args.config:
cfg.update_args(args.config) cfg.update_args(args.config)
register_coco(cfg.DATA.BASEDIR) # add COCO datasets to the registry
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model() MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
DetectionDataset() # initialize the config with information from our dataset
if args.visualize or args.evaluate or args.predict: if args.visualize or args.evaluate or args.predict:
if not tf.test.is_gpu_available(): if not tf.test.is_gpu_available():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment