Commit 8c8de86c authored by Yuxin Wu

[MaskRCNN] speed up COCO data loading

parent 9b3b5413
...@@ -108,30 +108,23 @@ class COCODetection(DatasetSplit): ...@@ -108,30 +108,23 @@ class COCODetection(DatasetSplit):
and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
'segmentation'. 'segmentation'.
""" """
if add_mask: with timed_operation('Load annotations for {}'.format(
assert add_gt
with timed_operation('Load Groundtruth Boxes for {}'.format(
os.path.basename(self.annotation_file))): os.path.basename(self.annotation_file))):
img_ids = self.coco.getImgIds() img_ids = self.coco.getImgIds()
img_ids.sort() img_ids.sort()
# list of dict, each has keys: height,width,id,file_name # list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids) imgs = self.coco.loadImgs(img_ids)
for img in tqdm.tqdm(imgs): for idx, img in enumerate(tqdm.tqdm(imgs)):
img['image_id'] = img.pop('id') img['image_id'] = img.pop('id')
self._use_absolute_file_name(img) img['file_name'] = os.path.join(self._imgdir, img['file_name'])
if idx == 0:
# make sure the directories are correctly set
assert os.path.isfile(img["file_name"]), img["file_name"]
if add_gt: if add_gt:
self._add_detection_gt(img, add_mask) self._add_detection_gt(img, add_mask)
return imgs return imgs
def _use_absolute_file_name(self, img):
"""
Change relative filename to absolute file name.

Joins ``self._imgdir`` with ``img['file_name']`` and stores the result
back into ``img['file_name']``, so the dict is modified in place.
Asserts that the resulting path is an existing file.
"""
img['file_name'] = os.path.join(
self._imgdir, img['file_name'])
# The failing path itself is used as the assertion message.
assert os.path.isfile(img['file_name']), img['file_name']
def _add_detection_gt(self, img, add_mask): def _add_detection_gt(self, img, add_mask):
""" """
Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection. Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
...@@ -147,57 +140,58 @@ class COCODetection(DatasetSplit): ...@@ -147,57 +140,58 @@ class COCODetection(DatasetSplit):
"Annotation ids in '{}' are not unique!".format(self.annotation_file) "Annotation ids in '{}' are not unique!".format(self.annotation_file)
# clean-up boxes # clean-up boxes
valid_objs = []
width = img.pop('width') width = img.pop('width')
height = img.pop('height') height = img.pop('height')
all_boxes = []
all_segm = []
all_cls = []
all_iscrowd = []
for objid, obj in enumerate(objs): for objid, obj in enumerate(objs):
if obj.get('ignore', 0) == 1: if obj.get('ignore', 0) == 1:
continue continue
x1, y1, w, h = obj['bbox'] x1, y1, w, h = list(map(float, obj['bbox']))
# bbox is originally in float # bbox is originally in float
# x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels. # x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
# But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel # But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
x2, y2 = x1 + w, y1 + h
x1 = np.clip(float(x1), 0, width)
y1 = np.clip(float(y1), 0, height) # np.clip would be quite slow here
w = np.clip(float(x1 + w), 0, width) - x1 x1 = min(max(x1, 0), width)
h = np.clip(float(y1 + h), 0, height) - y1 x2 = min(max(x2, 0), width)
y1 = min(max(y1, 0), height)
y2 = min(max(y2, 0), height)
w, h = x2 - x1, y2 - y1
# Require non-zero seg area and more than 1x1 box size # Require non-zero seg area and more than 1x1 box size
if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4: if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4:
obj['bbox'] = [x1, y1, x1 + w, y1 + h] all_boxes.append([x1, y1, x2, y2])
valid_objs.append(obj) all_cls.append(self.COCO_id_to_category_id.get(obj['category_id'], obj['category_id']))
iscrowd = obj.get("iscrowd", 0)
all_iscrowd.append(iscrowd)
if add_mask: if add_mask:
segs = obj['segmentation'] segs = obj['segmentation']
if not isinstance(segs, list): if not isinstance(segs, list):
assert obj['iscrowd'] == 1 assert iscrowd == 1
obj['segmentation'] = None all_segm.append(None)
else: else:
valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6] valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
if len(valid_segs) == 0: if len(valid_segs) == 0:
logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name'])) logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
elif len(valid_segs) < len(segs): elif len(valid_segs) < len(segs):
logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name'])) logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))
all_segm.append(valid_segs)
obj['segmentation'] = valid_segs
# all geometrically-valid boxes are returned # all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4) img['boxes'] = np.asarray(all_boxes, dtype='float32') # (n, 4)
cls = np.asarray([ cls = np.asarray(all_cls, dtype='int32') # (n,)
self.COCO_id_to_category_id.get(obj['category_id'], obj['category_id'])
for obj in valid_objs], dtype='int32') # (n,)
is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')
# add the keys
img['boxes'] = boxes # nx4
if len(cls): if len(cls):
assert cls.min() > 0, "Category id in COCO format must > 0!" assert cls.min() > 0, "Category id in COCO format must > 0!"
img['class'] = cls # n, always >0 img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n, img['is_crowd'] = np.asarray(all_iscrowd, dtype='int8') # n,
if add_mask: if add_mask:
# also required to be float32 # also required to be float32
img['segmentation'] = [ img['segmentation'] = all_segm
obj['segmentation'] for obj in valid_objs]
def training_roidbs(self): def training_roidbs(self):
return self.load(add_gt=True, add_mask=cfg.MODE_MASK) return self.load(add_gt=True, add_mask=cfg.MODE_MASK)
......
...@@ -265,7 +265,8 @@ class TFEventWriter(MonitorBase): ...@@ -265,7 +265,8 @@ class TFEventWriter(MonitorBase):
# Writing the graph is expensive (takes ~2min) when the graph is large. # Writing the graph is expensive (takes ~2min) when the graph is large.
# Therefore use a separate thread. It will then run in the # Therefore use a separate thread. It will then run in the
# background while TF is warming up in the first several iterations. # background while TF is warming up in the first several iterations.
self._write_graph_thread = threading.Thread(target=self._write_graph, daemon=True) self._write_graph_thread = threading.Thread(target=self._write_graph)
self._write_graph_thread.daemon = True
self._write_graph_thread.start() self._write_graph_thread.start()
@HIDE_DOC @HIDE_DOC
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment