Commit 42c9b8a7 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] clean-up segmentation code

parent 671b64cc
......@@ -34,7 +34,7 @@ Otherwise, you probably need different hyperparameters for the same performance.
To predict on an image (and show output in a window):
```
./train.py --predict input.jpg
./train.py --predict input.jpg --load /path/to/model
```
To evaluate the performance (pretrained models can be downloaded in [model zoo](https://drive.google.com/open?id=1J0xuDAuyOWiuJRm2LfGoz5PUv9_dKuxq):
......
......@@ -15,7 +15,6 @@ from tensorpack.utils.timer import timed_operation
from tensorpack.utils.argtools import log_once
from pycocotools.coco import COCO
import pycocotools.mask as cocomask
__all__ = ['COCODetection', 'COCOMeta']
......@@ -146,8 +145,6 @@ class COCODetection(object):
log_once("Image {} has invalid polygons!".format(img['file_name']), 'warn')
obj['segmentation'] = valid_segs
# rle = segmentation_to_rle(obj['segmentation'], height, width)
# obj['mask_rle'] = rle
# all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4)
......@@ -161,9 +158,7 @@ class COCODetection(object):
img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n,
if add_mask:
# mask_rles = [obj.pop('mask_rle') for obj in valid_objs]
# img['mask_rle'] = mask_rles # list, each is an RLE with full-image coordinate
img['segmentation'] = [obj['segmentation'] for obj in valid_segs]
img['segmentation'] = [obj['segmentation'] for obj in valid_objs]
del objs
......@@ -198,22 +193,6 @@ class COCODetection(object):
return ret
def segmentation_to_rle(segm, height, width):
if isinstance(segm, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = cocomask.frPyObjects(segm, height, width)
rle = cocomask.merge(rles)
elif isinstance(segm['counts'], list):
# uncompressed RLE
rle = cocomask.frPyObjects(segm, height, width)
else:
print("WTF?")
import IPython as IP
IP.embed()
return rle
if __name__ == '__main__':
c = COCODetection('/home/wyx/data/coco', 'train2014')
gt_boxes = c.load(add_gt=True, add_mask=True)
......
......@@ -10,8 +10,8 @@ import logging
from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
ProxyDataFlow, MapData, imgaug, TestDataSpeed,
MapDataComponent)
MapData, imgaug, TestDataSpeed,
MapDataComponent, DataFromList)
import tensorpack.utils.viz as tpviz
from tensorpack.utils.viz import interactive_imshow
......@@ -198,17 +198,14 @@ def get_train_dataflow():
# But this filter shall not be applied for testing.
imgs = list(filter(lambda img: len(img['boxes']) > 0, imgs)) # log invalid training
ds = DataFromListOfDict(
imgs,
['file_name', 'boxes', 'class', 'is_crowd'], # we need this four keys only
shuffle=True)
ds = DataFromList(imgs, shuffle=True)
aug = imgaug.AugmentorList(
[CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
imgaug.Flip(horiz=True)])
def preprocess(dp):
fname, boxes, klass, is_crowd = dp
def preprocess(img):
fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd']
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
im = im.astype('float32')
......@@ -252,6 +249,8 @@ def get_eval_dataflow():
if __name__ == '__main__':
config.BASEDIR = '/home/wyx/data/coco'
config.TRAIN_DATASET = ['train2014']
from tensorpack.dataflow import PrintData
ds = get_train_dataflow()
ds = PrintData(ds, 100)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment