Commit 4d041d06 authored by Yuxin Wu

[MaskRCNN] Add class_names to metadata in DatasetRegistry

parent 8932306f
@@ -24,7 +24,7 @@ Data:
 `YourDatasetSplit` can be:
 + `COCODetection`, if your data is already in COCO format. In this case, you need to
-  modify `COCODetection` to change the class names and the id mapping.
+  modify `dataset/coco.py` to change the class names and the id mapping.
 + Your own class, if your data is not in COCO format.
   You need to write a subclass of `DatasetSplit`, similar to `COCODetection`.
...
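For the second case (your own class), here is a minimal sketch of how a custom split could be registered after this commit. The `LicensePlates` class, its paths, and its field values are hypothetical; `DatasetSplit`, `DatasetRegistry.register`, `DatasetRegistry.register_metadata`, and the roidb field names follow the code changed in this diff, assuming `DatasetSplit` is exported alongside `DatasetRegistry` as the registry module's `__all__` suggests.

```python
import numpy as np

from dataset import DatasetRegistry, DatasetSplit


class LicensePlates(DatasetSplit):
    """A hypothetical non-COCO dataset with a single 'plate' category."""

    def __init__(self, basedir, split):
        self.basedir = basedir
        self.split = split

    def training_roidbs(self):
        # One dict per image, in the same format COCODetection produces:
        # "file_name", "boxes" (Nx4 float32), "class" (N, 1-based), "is_crowd" (N).
        return [{
            "file_name": "/data/plates/img_000001.jpg",   # hypothetical path
            "boxes": np.asarray([[10., 10., 80., 40.]], dtype=np.float32),
            "class": np.asarray([1], dtype=np.int32),
            "is_crowd": np.asarray([0], dtype=np.int8),
        }]


# Register every split together with its class names; "BG" must come first.
class_names = ["BG", "plate"]
for split in ["train", "val"]:
    name = "plates_" + split
    DatasetRegistry.register(name, lambda x=split: LicensePlates("/data/plates", x))
    DatasetRegistry.register_metadata(name, "class_names", class_names)
```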
@@ -88,7 +88,7 @@ def point8_to_box(points):
     return np.concatenate((minxy, maxxy), axis=1)

-def segmentation_to_mask(polys, height, width):
+def polygons_to_mask(polys, height, width):
     """
     Convert polygons to binary masks.
...
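The diff renames the helper without showing its body. For reference, this kind of polygon-to-mask conversion is commonly implemented with `pycocotools`; the following is a sketch of such an implementation, not necessarily the exact body in `common.py`.

```python
import pycocotools.mask as cocomask


def polygons_to_mask(polys, height, width):
    """
    Convert polygons to a binary mask.

    Args:
        polys: a list of nx2 float arrays, the (x, y) vertices of each polygon

    Returns:
        a (height, width) uint8 array, 1 inside the union of the polygons
    """
    polys = [p.flatten().tolist() for p in polys]
    rles = cocomask.frPyObjects(polys, height, width)  # RLE-encode each polygon
    rle = cocomask.merge(rles)                         # union of all polygons
    return cocomask.decode(rle)                        # back to a dense binary mask
```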
@@ -90,9 +90,10 @@ _C.DATA.TRAIN = ('coco_train2017',)  # i.e. trainval35k
 # Each VAL dataset will be evaluated separately (instead of concatenated)
 _C.DATA.VAL = ('coco_val2017',)  # AKA minival2014
-# This two config will be populated later by the dataset loader:
-_C.DATA.NUM_CATEGORY = 80  # without the background class (e.g., 80 for COCO)
+# These two configs will be populated later inside `finalize_configs`.
+_C.DATA.NUM_CATEGORY = -1  # without the background class (e.g., 80 for COCO)
 _C.DATA.CLASS_NAMES = []  # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
 # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
 _C.DATA.ABSOLUTE_COORD = True
 # Number of data loading workers.
@@ -228,6 +229,12 @@ def finalize_configs(is_training):
     if isinstance(_C.DATA.TRAIN, six.string_types):  # support single string
         _C.DATA.TRAIN = (_C.DATA.TRAIN, )
+
+    # finalize dataset definitions ...
+    from dataset import DatasetRegistry
+    datasets = list(_C.DATA.TRAIN) + list(_C.DATA.VAL)
+    _C.DATA.CLASS_NAMES = DatasetRegistry.get_metadata(datasets[0], "class_names")
+    _C.DATA.NUM_CATEGORY = len(_C.DATA.CLASS_NAMES) - 1
     assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM
     if _C.BACKBONE.NORM != 'FreezeBN':
         assert not _C.BACKBONE.FREEZE_AFFINE
...
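With this change, class information flows from the dataset registry into the config instead of being hard-coded. A sketch of the intended order of operations in a training script; it assumes `register_coco` is importable from the `dataset` package, as the COCO code further down suggests, and the basedir is hypothetical.

```python
from config import config as cfg, finalize_configs
from dataset import register_coco

register_coco("/data/coco")           # registers splits and their 'class_names' metadata
cfg.DATA.TRAIN = ("coco_train2017",)
cfg.DATA.VAL = ("coco_val2017",)

finalize_configs(is_training=True)    # pulls class_names from DatasetRegistry

assert cfg.DATA.CLASS_NAMES[0] == "BG"
assert cfg.DATA.NUM_CATEGORY == len(cfg.DATA.CLASS_NAMES) - 1   # 80 for COCO
```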
@@ -19,7 +19,7 @@ from modeling.model_rpn import get_all_anchors
 from modeling.model_fpn import get_all_anchors_fpn
 from common import (
     CustomResize, DataFromListOfDict, box_to_point8,
-    filter_boxes_inside_shape, np_iou, point8_to_box, segmentation_to_mask,
+    filter_boxes_inside_shape, np_iou, point8_to_box, polygons_to_mask,
 )
 from config import config as cfg
 from dataset import DatasetRegistry
@@ -38,6 +38,7 @@ def print_class_histogram(roidbs):
     Args:
         roidbs (list[dict]): the same format as the output of `training_roidbs`.
     """
+    class_names = DatasetRegistry.get_metadata(cfg.DATA.TRAIN[0], 'class_names')
     # labels are in [1, NUM_CATEGORY], hence +2 for bins
     hist_bins = np.arange(cfg.DATA.NUM_CATEGORY + 2)
@@ -49,7 +50,7 @@ def print_class_histogram(roidbs):
         gt_classes = entry["class"][gt_inds]
         gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
     COL = 6
-    data = list(itertools.chain(*[[cfg.DATA.CLASS_NAMES[i], v] for i, v in enumerate(gt_hist[1:])]))
+    data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])]))
     total_instances = sum(data[1::2])
     data.extend([None] * (COL - len(data) % COL))
     data.extend(["total", total_instances])
@@ -83,7 +84,7 @@ class TrainingDataPreprocessor:
         im = im.astype("float32")
         height, width = im.shape[:2]
         # assume floatbox as input
-        assert boxes.dtype == np.float32, "Loader has to return floating point boxes!"
+        assert boxes.dtype == np.float32, "Loader has to return float32 boxes!"
         if not self.cfg.DATA.ABSOLUTE_COORD:
             boxes[:, 0::2] *= width
@@ -133,7 +134,7 @@ class TrainingDataPreprocessor:
                 if not self.cfg.DATA.ABSOLUTE_COORD:
                     polys = [p * width_height for p in polys]
                 polys = [tfms.apply_coords(p) for p in polys]
-                masks.append(segmentation_to_mask(polys, im.shape[0], gt_mask_width))
+                masks.append(polygons_to_mask(polys, im.shape[0], gt_mask_width))
             if len(masks):
                 masks = np.asarray(masks, dtype='uint8')  # values in {0, 1}
@@ -219,6 +220,7 @@ class TrainingDataPreprocessor:
         all_anchors_flatten = np.concatenate(flatten_anchors_per_level, axis=0)
         inside_ind, inside_anchors = filter_boxes_inside_shape(all_anchors_flatten, im.shape[:2])
         anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
             inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1]
         )
...
@@ -27,14 +27,6 @@ class COCODetection(DatasetSplit):
     """
     COCO_id_to_category_id = {13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}  # noqa
-
-    """
-    80 names for COCO
-    For your own coco-format dataset, change this.
-    """
-    class_names = [
-        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]  # noqa
-    cfg.DATA.CLASS_NAMES = ["BG"] + class_names

     def __init__(self, basedir, split):
         """
         Args:
@@ -230,10 +222,18 @@ def register_coco(basedir):
     Note that train2017==trainval35k==train2014+val2014-minival2014, and val2017==minival2014.
     """
-    for split in ["train2017", "val2017", "train2014", "val2014",
-                  "valminusminival2014", "minival2014"]:
-        DatasetRegistry.register("coco_" + split, lambda x=split: COCODetection(basedir, x))
+    # 80 names for COCO
+    # For your own coco-format dataset, change this.
+    class_names = [
+        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]  # noqa
+    class_names = ["BG"] + class_names
+
+    for split in ["train2017", "val2017", "train2014", "val2014",
+                  "valminusminival2014", "minival2014", "trainsingle"]:
+        name = "coco_" + split
+        DatasetRegistry.register(name, lambda x=split: COCODetection(basedir, x))
+        DatasetRegistry.register_metadata(name, 'class_names', class_names)


 if __name__ == '__main__':
     basedir = '~/data/coco'
...
 # -*- coding: utf-8 -*-
+from collections import defaultdict

 __all__ = ['DatasetRegistry', 'DatasetSplit']

@@ -68,6 +71,7 @@ class DatasetSplit():
 class DatasetRegistry():
     _registry = {}
+    _metadata_registry = defaultdict(dict)

     @staticmethod
     def register(name, func):

@@ -90,3 +94,25 @@ class DatasetRegistry():
         """
         assert name in DatasetRegistry._registry, "Dataset {} was not registered!".format(name)
         return DatasetRegistry._registry[name]()
+
+    @staticmethod
+    def register_metadata(name, key, value):
+        """
+        Args:
+            name (str): the name of the dataset split, e.g. "coco_train2017"
+            key: the key of the metadata, e.g., "class_names"
+            value: the value of the metadata
+        """
+        DatasetRegistry._metadata_registry[name][key] = value
+
+    @staticmethod
+    def get_metadata(name, key):
+        """
+        Args:
+            name (str): the name of the dataset split, e.g. "coco_train2017"
+            key: the key of the metadata, e.g., "class_names"
+
+        Returns:
+            value
+        """
+        return DatasetRegistry._metadata_registry[name][key]
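The two new static methods give a process-global metadata store keyed by dataset name. A quick round trip, assuming `register_coco` (defined above) is importable from the `dataset` package; reading metadata does not touch anything on disk.

```python
from dataset import DatasetRegistry, register_coco

register_coco("/data/coco")   # hypothetical location; registration is lazy, nothing is loaded

names = DatasetRegistry.get_metadata("coco_train2017", "class_names")
assert names[0] == "BG" and len(names) == 81   # "BG" + 80 COCO categories

# By contrast, DatasetRegistry.get("coco_train2017") would instantiate
# COCODetection and therefore needs the annotation files to exist.
```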
@@ -10,9 +10,10 @@ from tensorpack.utils.palette import PALETTE_RGB
 from config import config as cfg
 from utils.np_box_ops import area as np_area
 from utils.np_box_ops import iou as np_iou
+from common import polygons_to_mask

-def draw_annotation(img, boxes, klass, is_crowd=None):
+def draw_annotation(img, boxes, klass, polygons=None, is_crowd=None):
     """Will not modify img"""
     labels = []
     assert len(boxes) == len(klass)

@@ -27,6 +28,11 @@ def draw_annotation(img, boxes, klass, is_crowd=None):
     for cls in klass:
         labels.append(cfg.DATA.CLASS_NAMES[cls])
     img = viz.draw_boxes(img, boxes, labels)
+
+    if polygons is not None:
+        for p in polygons:
+            mask = polygons_to_mask(p, img.shape[0], img.shape[1])
+            img = draw_mask(img, mask)
     return img

@@ -102,6 +108,7 @@ def draw_mask(im, mask, alpha=0.5, color=None):
     """
     if color is None:
         color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]
+    color = np.asarray(color, dtype=np.float32)
     im = np.where(np.repeat((mask > 0)[:, :, None], 3, axis=2),
                   im * (1 - alpha) + color * alpha, im)
     im = im.astype('uint8')
...
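With the new `polygons` argument, ground-truth masks can be visualized next to the boxes. A sketch of a call; the image path and class list are hypothetical, and note the nesting of the argument: one list of polygons per box, each polygon an nx2 array of (x, y) vertices.

```python
import cv2
import numpy as np

from config import config as cfg
from viz import draw_annotation

# draw_annotation looks up label strings in cfg.DATA.CLASS_NAMES,
# so make sure it is populated (normally done by finalize_configs).
cfg.DATA.CLASS_NAMES = ["BG", "plate"]           # hypothetical classes

img = cv2.imread("/data/plates/img_000001.jpg")  # hypothetical image
boxes = np.asarray([[30., 30., 200., 160.]], dtype=np.float32)
klass = np.asarray([1], dtype=np.int32)
polygons = [[np.asarray([[30, 30], [200, 30], [200, 160], [30, 160]], dtype=np.float32)]]

viz_img = draw_annotation(img, boxes, klass, polygons=polygons)
cv2.imwrite("groundtruth.png", viz_img)
```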