Commit 9229285e authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] clean-ups about masks & segs

parent fabb7e7e
......@@ -138,16 +138,16 @@ class COCODetection(object):
if add_mask:
segs = obj['segmentation']
if not isinstance(segs, list):
# TODO
assert obj['iscrowd'] == 1
obj['segmentation'] = None
else:
valid_segs = [p for p in segs if len(p) >= 6]
valid_segs = [np.asarray(p).reshape(-1, 2) for p in segs if len(p) >= 6]
if len(valid_segs) < len(segs):
log_once("Image {} has invalid polygons!".format(img['file_name']), 'warn')
obj['segmentation'] = valid_segs
rle = segmentation_to_rle(obj['segmentation'], height, width)
obj['mask_rle'] = rle
# rle = segmentation_to_rle(obj['segmentation'], height, width)
# obj['mask_rle'] = rle
# all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4)
......@@ -161,8 +161,9 @@ class COCODetection(object):
img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n,
if add_mask:
mask_rles = [obj.pop('mask_rle') for obj in valid_objs]
img['mask_rles'] = mask_rles # list, each is an RLE with full-image coordinate
# mask_rles = [obj.pop('mask_rle') for obj in valid_objs]
# img['mask_rle'] = mask_rles # list, each is an RLE with full-image coordinate
img['segmentation'] = [obj['segmentation'] for obj in valid_segs]
del objs
......@@ -184,7 +185,7 @@ class COCODetection(object):
logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
@staticmethod
def load_many(basedir, names, add_gt=True):
def load_many(basedir, names, add_gt=True, add_mask=False):
"""
Load and merges several instance files together.
"""
......@@ -193,7 +194,7 @@ class COCODetection(object):
ret = []
for n in names:
coco = COCODetection(basedir, n)
ret.extend(coco.load(add_gt))
ret.extend(coco.load(add_gt, add_mask=add_mask))
return ret
......
......@@ -6,6 +6,7 @@ from six.moves import zip
import numpy as np
from tensorpack.utils import viz
from tensorpack.utils.palette import PALETTE_RGB
from coco import COCOMeta
from utils.box_ops import get_iou_callable
......@@ -77,3 +78,20 @@ def draw_final_outputs(img, results):
if all_boxes.shape[0] == 0:
return img
return viz.draw_boxes(img, all_boxes, all_tags)
def draw_mask(im, mask, alpha=0.5, color=None):
"""
Overlay a mask on top of the image.
Args:
im: a 3-channel uint8 image in BGR
mask: a binary 1-channel image of the same size
color: if None, will choose automatically
"""
if color is None:
color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]
im = np.where(np.repeat((mask > 0)[:, :, None], 3, axis=2),
im * (1 - alpha) + color * alpha, im)
im = im.astype('uint8')
return im
......@@ -36,6 +36,14 @@ class Augmentor(object):
d, params = self._augment_return_params(d)
return d
def augment_return_params(self, d):
"""
augmented data
augmentaion params
"""
return self._augment_return_params(d)
def _augment_return_params(self, d):
"""
Augment the image and return both image and params
......@@ -93,6 +101,9 @@ class Augmentor(object):
class ImageAugmentor(Augmentor):
def augment_coords(self, coords, param):
return self._augment_coords(coords, param)
def _augment_coords(self, coords, param):
"""
Augment the coordinates given the param.
......
......@@ -357,7 +357,7 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
def draw_boxes(im, boxes, labels=None, color=None):
"""
Args:
im (np.ndarray): a BGR image. It will not be modified.
im (np.ndarray): a BGR image in range [0,255]. It will not be modified.
boxes (np.ndarray or list[BoxBase]): If an ndarray,
must be of shape Nx4 where the second dimension is [x1, y1, x2, y2].
labels: (list[str] or None)
......@@ -389,7 +389,7 @@ def draw_boxes(im, boxes, labels=None, color=None):
im = im.copy()
COLOR = (218, 218, 218) if color is None else color
COLOR_DIFF_WEIGHT = np.asarray((3, 4, 2), dtype='int32') # https://www.wikiwand.com/en/Color_difference
COLOR_CANDIDATES = PALETTE_RGB[[0, 1, 2, 3, 18, 113], :]
COLOR_CANDIDATES = PALETTE_RGB[:, ::-1]
if im.ndim == 2 or (im.ndim == 3 and im.shape[2] == 1):
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
for i in sorted_inds:
......@@ -446,7 +446,7 @@ if __name__ == '__main__':
img2 = cv2.resize(img, (300, 300))
viz = stack_patches([img, img2], 1, 2, pad=True, viz=True)
if True:
if False:
img = cv2.imread('cat.jpg')
boxes = np.asarray([
[10, 30, 200, 100],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment