Commit 29c81dd8 authored by Yuxin Wu

optionally include mask in train_dataflow

parent 1525e800
......@@ -4,10 +4,13 @@
import numpy as np
import cv2
from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import transform
from tensorpack.utils import logger
import pycocotools.mask as cocomask
import config
......@@ -85,6 +88,22 @@ def point8_to_box(points):
return np.concatenate((minxy, maxxy), axis=1)
def segmentation_to_mask(polys, height, width):
    """
    Rasterize a group of polygons into one binary mask.

    Args:
        polys: a list of nx2 float array
        height, width: forwarded to ``cocomask.frPyObjects`` together with
            the polygons; they give the size of the returned matrix.

    Returns:
        a binary matrix of (height, width)
    """
    # cocomask.frPyObjects is given plain python lists, so every nx2 array
    # is flattened into a 1-D list of numbers first.
    flat_polys = []
    for poly in polys:
        flat_polys.append(poly.flatten().tolist())

    # Encode each polygon as RLE, merge all of them into a single RLE,
    # then decode that RLE into the final matrix.
    encoded = cocomask.frPyObjects(flat_polys, height, width)
    return cocomask.decode(cocomask.merge(encoded))
def clip_boxes(boxes, shape):
"""
Args:
......
......@@ -20,7 +20,7 @@ from utils.generate_anchors import generate_anchors
from utils.box_ops import get_iou_callable
from common import (
DataFromListOfDict, CustomResize,
box_to_point8, point8_to_box)
box_to_point8, point8_to_box, segmentation_to_mask)
import config
......@@ -192,8 +192,13 @@ def get_rpn_anchor_input(im, boxes, klass, is_crowd):
return featuremap_labels, featuremap_boxes
def get_train_dataflow():
imgs = COCODetection.load_many(config.BASEDIR, config.TRAIN_DATASET)
def get_train_dataflow(add_mask=False):
"""
Return a training dataflow. Each datapoint is:
image, fm_labels, fm_boxes, gt_boxes, gt_class [, masks]
"""
imgs = COCODetection.load_many(
config.BASEDIR, config.TRAIN_DATASET, add_gt=True, add_mask=add_mask)
# Valid training images should have at least one fg box.
# But this filter shall not be applied for testing.
imgs = list(filter(lambda img: len(img['boxes']) > 0, imgs)) # log invalid training
......@@ -229,7 +234,28 @@ def get_train_dataflow():
log_once("Input {} is invalid for training: {}".format(fname, str(e)), 'warn')
return None
return im, fm_labels, fm_boxes, boxes, klass
ret = [im, fm_labels, fm_boxes, boxes, klass]
# masks
segmentation = img.get('segmentation', None)
if segmentation is not None:
segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
assert len(segmentation) == len(boxes)
# one image-sized binary mask per box
masks = []
for box, polys in zip(boxes, segmentation):
polys = [aug.augment_coords(p, params) for p in polys]
masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
masks = np.asarray(masks, dtype='uint8')
ret.append(masks)
# from viz import draw_annotation, draw_mask
# viz = draw_annotation(im, boxes, klass)
# for mask in masks:
# viz = draw_mask(viz, mask)
# tpviz.interactive_imshow(viz)
return ret
ds = MapData(ds, preprocess)
return ds
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment