Commit e8b1a84a authored by Yuxin Wu

[MaskRCNN] pack gt masks to bits

parent de125c4e
@@ -179,13 +179,15 @@ class TrainingDataPreprocessor:
         # And produce one image-sized binary mask per box.
         masks = []
         width_height = np.asarray([width, height], dtype=np.float32)
+        gt_mask_width = int(np.ceil(im.shape[1] / 8.0) * 8)  # pad to 8 in order to pack mask into bits
         for polys in segmentation:
             if not self.cfg.DATA.ABSOLUTE_COORD:
                 polys = [p * width_height for p in polys]
             polys = [self.aug.augment_coords(p, params) for p in polys]
-            masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
+            masks.append(segmentation_to_mask(polys, im.shape[0], gt_mask_width))
         masks = np.asarray(masks, dtype='uint8')  # values in {0, 1}
-        ret['gt_masks'] = masks
+        masks = np.packbits(masks, axis=-1)
+        ret['gt_masks_packed'] = masks
         # from viz import draw_annotation, draw_mask
         # viz = draw_annotation(im, boxes, klass)
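A minimal NumPy sketch (not from this commit) of the padding and packing arithmetic above, using a hypothetical 480x642 image: the mask width is padded up to a multiple of 8 so that np.packbits can fold every 8 columns into one byte, storing each ground-truth mask in roughly 1/8 of the memory.

import numpy as np

height, width = 480, 642                                  # hypothetical image size
gt_mask_width = int(np.ceil(width / 8.0) * 8)             # 648, padded to a multiple of 8
mask = np.zeros((height, gt_mask_width), dtype='uint8')   # one binary mask, values in {0, 1}
packed = np.packbits(mask, axis=-1)                       # uint8, shape (480, 81): 8 columns per byte
print(mask.nbytes, packed.nbytes)                         # 311040 vs 38880 bytes, ~8x smaller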
@@ -20,7 +20,7 @@ from .model_cascade import CascadeRCNNHead
 from .model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
 from .model_frcnn import (
     BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets)
-from .model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
+from .model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head, unpackbits_masks
 from .model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
@@ -62,6 +62,9 @@ class GeneralizedRCNN(ModelDesc):
     def build_graph(self, *inputs):
         inputs = dict(zip(self.input_names, inputs))
+        if "gt_masks_packed" in inputs:
+            gt_masks = tf.cast(unpackbits_masks(inputs.pop("gt_masks_packed")), tf.uint8, name="gt_masks")
+            inputs["gt_masks"] = gt_masks
         image = self.preprocess(inputs['image'])  # 1CHW
@@ -91,8 +94,8 @@ class ResNetC4Model(GeneralizedRCNN):
             tf.TensorSpec((None,), tf.int64, 'gt_labels')]  # all > 0
         if cfg.MODE_MASK:
             ret.append(
-                tf.TensorSpec((None, None, None), tf.uint8, 'gt_masks')
-            )   # NR_GT x height x width
+                tf.TensorSpec((None, None, None), tf.uint8, 'gt_masks_packed')
+            )   # NR_GT x height x ceil(width/8), packed groundtruth masks
         return ret

     def backbone(self, image):
@@ -202,8 +205,8 @@ class ResNetFPNModel(GeneralizedRCNN):
             tf.TensorSpec((None,), tf.int64, 'gt_labels')])  # all > 0
         if cfg.MODE_MASK:
             ret.append(
-                tf.TensorSpec((None, None, None), tf.uint8, 'gt_masks')
-            )   # NR_GT x height x width
+                tf.TensorSpec((None, None, None), tf.uint8, 'gt_masks_packed')
+            )
         return ret

     def slice_feature_and_anchors(self, p23456, anchors):
@@ -85,3 +85,22 @@ def maskrcnn_up4conv_head(*args, **kwargs):
 def maskrcnn_up4conv_gn_head(*args, **kwargs):
     return maskrcnn_upXconv_head(*args, num_convs=4, norm='GN', **kwargs)
+
+
+def unpackbits_masks(masks):
+    """
+    Args:
+        masks (Tensor): uint8 Tensor of shape N, H, W. The last dimension is packed bits.
+
+    Returns:
+        masks (Tensor): bool Tensor of shape N, H, 8*W.
+
+    This is a reverse operation of `np.packbits`
+    """
+    assert masks.dtype == tf.uint8, masks
+    bits = tf.constant((128, 64, 32, 16, 8, 4, 2, 1), dtype=tf.uint8)
+    unpacked = tf.bitwise.bitwise_and(tf.expand_dims(masks, -1), bits) > 0
+    unpacked = tf.reshape(
+        unpacked,
+        tf.concat([tf.shape(masks)[:-1], [-1]], axis=0))
+    return unpacked
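A minimal round-trip check for unpackbits_masks (my sketch, not part of the commit), assuming TF2 eager execution and that the function is importable as below; the import path is assumed for illustration, and under TF1 the result would have to be evaluated in a session instead of calling .numpy().

import numpy as np
import tensorflow as tf

from model_mrcnn import unpackbits_masks   # import path assumed for illustration

rng = np.random.RandomState(0)
masks = (rng.rand(3, 5, 16) > 0.5).astype('uint8')   # binary masks; width already a multiple of 8
packed = np.packbits(masks, axis=-1)                 # uint8, shape (3, 5, 2), MSB-first like the bits constant

unpacked = unpackbits_masks(tf.constant(packed))     # bool, shape (3, 5, 16)
np.testing.assert_array_equal(unpacked.numpy(), masks.astype(bool))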
@@ -238,7 +238,7 @@ def start_proc_mask_signal(proc):
         if sys.version_info < (3, 4) or mp.get_start_method() == 'fork':
             log_once(
                 "Starting a process with 'fork' method is not safe and may consume unnecessary extra memory."
-                " Use 'forkserver' method (available after Py3.4) instead if you run into any issues. "
+                " Use 'forkserver/spawn' method (available after Py3.4) instead if you run into any issues. "
                 "See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods",
                 'warn')  # noqa
         p.start()