Commit 37530d96 authored by Yuxin Wu

[MaskRCNN] small change on data loading

parent d2309a1b
...@@ -12,7 +12,7 @@ with the support of: ...@@ -12,7 +12,7 @@ with the support of:
## Dependencies ## Dependencies
+ Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug); + Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
+ [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV. + [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/), OpenCV.
+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/) + Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/)
from tensorpack model zoo. Use the models with "-AlignPadding". from tensorpack model zoo. Use the models with "-AlignPadding".
+ COCO data. It needs to have the following directory structure: + COCO data. It needs to have the following directory structure:
......
...@@ -116,8 +116,9 @@ class COCODetection(object): ...@@ -116,8 +116,9 @@ class COCODetection(object):
Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection. Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
If add_mask is True, also add 'segmentation' in coco poly format. If add_mask is True, also add 'segmentation' in coco poly format.
""" """
ann_ids = self.coco.getAnnIds(imgIds=img['id'], iscrowd=None) # ann_ids = self.coco.getAnnIds(imgIds=img['id'])
objs = self.coco.loadAnns(ann_ids) # objs = self.coco.loadAnns(ann_ids)
objs = self.coco.imgToAnns[img['id']] # equivalent but faster than the above two lines
# clean-up boxes # clean-up boxes
valid_objs = [] valid_objs = []
......
...@@ -164,6 +164,7 @@ def finalize_configs(is_training): ...@@ -164,6 +164,7 @@ def finalize_configs(is_training):
Run some sanity checks, and populate some configs from others Run some sanity checks, and populate some configs from others
""" """
_C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background _C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background
_C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN': if _C.BACKBONE.NORM != 'FreezeBN':
......
...@@ -89,7 +89,7 @@ class DetectionModel(ModelDesc): ...@@ -89,7 +89,7 @@ class DetectionModel(ModelDesc):
with tf.name_scope('fg_sample_patch_viz'): with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize( fg_sampled_patches = crop_and_resize(
image, fg_rcnn_boxes, image, fg_rcnn_boxes,
tf.zeros(tf.shape(fg_rcnn_boxes)[0], dtype=tf.int32), 300) tf.zeros([tf.shape(fg_rcnn_boxes)[0]], dtype=tf.int32), 300)
fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1]) fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1]) # BGR->RGB fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1]) # BGR->RGB
tf.summary.image('viz', fg_sampled_patches, max_outputs=30) tf.summary.image('viz', fg_sampled_patches, max_outputs=30)
...@@ -517,7 +517,8 @@ if __name__ == '__main__': ...@@ -517,7 +517,8 @@ if __name__ == '__main__':
logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.") logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.")
args = parser.parse_args() args = parser.parse_args()
cfg.update_args(args.config) if args.config:
cfg.update_args(args.config)
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model() MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment