Commit 8c8de86c authored by Yuxin Wu

[MaskRCNN] speed up COCO data loading

parent 9b3b5413
......@@ -108,30 +108,23 @@ class COCODetection(DatasetSplit):
and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
'segmentation'.
"""
if add_mask:
assert add_gt
with timed_operation('Load Groundtruth Boxes for {}'.format(
with timed_operation('Load annotations for {}'.format(
os.path.basename(self.annotation_file))):
img_ids = self.coco.getImgIds()
img_ids.sort()
# list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids)
for img in tqdm.tqdm(imgs):
for idx, img in enumerate(tqdm.tqdm(imgs)):
img['image_id'] = img.pop('id')
self._use_absolute_file_name(img)
img['file_name'] = os.path.join(self._imgdir, img['file_name'])
if idx == 0:
# make sure the directories are correctly set
assert os.path.isfile(img["file_name"]), img["file_name"]
if add_gt:
self._add_detection_gt(img, add_mask)
return imgs
def _use_absolute_file_name(self, img):
"""
Change relative filename to absolute file name.
"""
img['file_name'] = os.path.join(
self._imgdir, img['file_name'])
assert os.path.isfile(img['file_name']), img['file_name']
def _add_detection_gt(self, img, add_mask):
"""
Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
......@@ -147,57 +140,58 @@ class COCODetection(DatasetSplit):
"Annotation ids in '{}' are not unique!".format(self.annotation_file)
# clean-up boxes
valid_objs = []
width = img.pop('width')
height = img.pop('height')
all_boxes = []
all_segm = []
all_cls = []
all_iscrowd = []
for objid, obj in enumerate(objs):
if obj.get('ignore', 0) == 1:
continue
x1, y1, w, h = obj['bbox']
x1, y1, w, h = list(map(float, obj['bbox']))
# bbox is originally in float
# x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
# But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
x1 = np.clip(float(x1), 0, width)
y1 = np.clip(float(y1), 0, height)
w = np.clip(float(x1 + w), 0, width) - x1
h = np.clip(float(y1 + h), 0, height) - y1
x2, y2 = x1 + w, y1 + h
# np.clip would be quite slow here
x1 = min(max(x1, 0), width)
x2 = min(max(x2, 0), width)
y1 = min(max(y1, 0), height)
y2 = min(max(y2, 0), height)
w, h = x2 - x1, y2 - y1
# Require non-zero seg area and more than 1x1 box size
if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4:
obj['bbox'] = [x1, y1, x1 + w, y1 + h]
valid_objs.append(obj)
all_boxes.append([x1, y1, x2, y2])
all_cls.append(self.COCO_id_to_category_id.get(obj['category_id'], obj['category_id']))
iscrowd = obj.get("iscrowd", 0)
all_iscrowd.append(iscrowd)
if add_mask:
segs = obj['segmentation']
if not isinstance(segs, list):
assert obj['iscrowd'] == 1
obj['segmentation'] = None
assert iscrowd == 1
all_segm.append(None)
else:
valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
if len(valid_segs) == 0:
logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
elif len(valid_segs) < len(segs):
logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))
obj['segmentation'] = valid_segs
all_segm.append(valid_segs)
# all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4)
cls = np.asarray([
self.COCO_id_to_category_id.get(obj['category_id'], obj['category_id'])
for obj in valid_objs], dtype='int32') # (n,)
is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')
# add the keys
img['boxes'] = boxes # nx4
img['boxes'] = np.asarray(all_boxes, dtype='float32') # (n, 4)
cls = np.asarray(all_cls, dtype='int32') # (n,)
if len(cls):
assert cls.min() > 0, "Category id in COCO format must > 0!"
img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n,
img['is_crowd'] = np.asarray(all_iscrowd, dtype='int8') # n,
if add_mask:
# also required to be float32
img['segmentation'] = [
obj['segmentation'] for obj in valid_objs]
img['segmentation'] = all_segm
def training_roidbs(self):
return self.load(add_gt=True, add_mask=cfg.MODE_MASK)
......
......@@ -265,7 +265,8 @@ class TFEventWriter(MonitorBase):
# Writing the graph is expensive (takes ~2min) when the graph is large.
# Therefore use a separate thread. It will then run in the
# background while TF is warming up in the first several iterations.
self._write_graph_thread = threading.Thread(target=self._write_graph, daemon=True)
self._write_graph_thread = threading.Thread(target=self._write_graph)
self._write_graph_thread.daemon = True
self._write_graph_thread.start()
@HIDE_DOC
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment