Commit 6a0d33d1 authored by Yuxin Wu

[MaskRCNN] bugfix on data loading

parent 94499e81
...@@ -144,7 +144,7 @@ class COCODetection(object): ...@@ -144,7 +144,7 @@ class COCODetection(object):
assert obj['iscrowd'] == 1 assert obj['iscrowd'] == 1
obj['segmentation'] = None obj['segmentation'] = None
else: else:
valid_segs = [np.asarray(p).reshape(-1, 2) for p in segs if len(p) >= 6] valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
if len(valid_segs) < len(segs): if len(valid_segs) < len(segs):
log_once("Image {} has invalid polygons!".format(img['file_name']), 'warn') log_once("Image {} has invalid polygons!".format(img['file_name']), 'warn')
...@@ -164,7 +164,7 @@ class COCODetection(object): ...@@ -164,7 +164,7 @@ class COCODetection(object):
if add_mask: if add_mask:
# also required to be float32 # also required to be float32
img['segmentation'] = [ img['segmentation'] = [
obj['segmentation'].astype('float32') for obj in valid_objs] obj['segmentation'] for obj in valid_objs]
def print_class_histogram(self, imgs): def print_class_histogram(self, imgs):
nr_class = len(COCOMeta.class_names) nr_class = len(COCOMeta.class_names)
......
...@@ -272,8 +272,12 @@ def get_train_dataflow(): ...@@ -272,8 +272,12 @@ def get_train_dataflow():
boxes: kx4 floats boxes: kx4 floats
class: k integers class: k integers
is_crowd: k booleans. Use k False if you don't know what it means. is_crowd: k booleans. Use k False if you don't know what it means.
segmentation: k numpy arrays. Each array is a polygon of shape Nx2. segmentation: k lists of numpy arrays (one for each box).
If your segmentation annotations are masks rather than polygons, Each list of numpy array corresponds to the mask for one instance.
Each numpy array in the list is a polygon of shape Nx2,
because one mask can be represented by N polygons.
If your segmentation annotations are originally masks rather than polygons,
either convert it, or the augmentation code below will need to be either convert it, or the augmentation code below will need to be
changed or skipped accordingly. changed or skipped accordingly.
""" """
...@@ -369,7 +373,6 @@ def get_eval_dataflow(): ...@@ -369,7 +373,6 @@ def get_eval_dataflow():
if __name__ == '__main__': if __name__ == '__main__':
import os import os
# import IPython as IP; IP.embed()
from tensorpack.dataflow import PrintData from tensorpack.dataflow import PrintData
config.BASEDIR = os.path.expanduser('~/data/coco') config.BASEDIR = os.path.expanduser('~/data/coco')
ds = get_train_dataflow() ds = get_train_dataflow()
......
...@@ -377,7 +377,7 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True): ...@@ -377,7 +377,7 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size]) boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size])
image = tf.transpose(image, [0, 2, 3, 1]) # 1hwc image = tf.transpose(image, [0, 2, 3, 1]) # 1hwc
ret = tf.image.crop_and_resize( ret = tf.image.crop_and_resize(
image, boxes, box_ind, image, boxes, tf.to_int32(box_ind),
crop_size=[crop_size, crop_size]) crop_size=[crop_size, crop_size])
ret = tf.transpose(ret, [0, 3, 1, 2]) # ncss ret = tf.transpose(ret, [0, 3, 1, 2]) # ncss
return ret return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment