Commit 49e04ffa authored by Yuxin Wu

[MaskRCNN] use dict as input

parent 787be08e
@@ -4,7 +4,6 @@
 import cv2
 import numpy as np
 import copy
-import itertools

 from tensorpack.utils.argtools import memoized, log_once
 from tensorpack.dataflow import (
@@ -282,11 +281,11 @@ def get_train_dataflow():
     If MODE_MASK, gt_masks: (N, h, w)
     """
-    imgs = COCODetection.load_many(
+    roidbs = COCODetection.load_many(
         cfg.DATA.BASEDIR, cfg.DATA.TRAIN, add_gt=True, add_mask=cfg.MODE_MASK)
     """
     To train on your own data, change this to your loader.
-    Produce "imgs" as a list of dict, in the dict the following keys are needed for training:
+    Produce "roidbs" as a list of dict, in the dict the following keys are needed for training:
     height, width: integer
     file_name: str, full path to the image
     boxes: numpy array of kx4 floats
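The docstring above is cut off by the hunk boundary, but the keys `class`, `is_crowd` and (for mask training) `segmentation` are also read by `preprocess` below. A minimal sketch of one such roidb entry, with made-up values and a hypothetical file path, assuming `numpy` is imported as `np`:

```python
# Hypothetical roidb entry with two ground-truth boxes; the values are
# illustrative, only the keys and rough dtypes matter.
roidb = {
    'height': 480, 'width': 640,
    'file_name': '/path/to/train2017/000001.jpg',
    'boxes': np.array([[10, 20, 200, 240],
                       [300, 50, 500, 400]], dtype='float32'),  # k x 4 floats
    'class': np.array([1, 3], dtype='int32'),    # k category labels
    'is_crowd': np.array([0, 0], dtype='int8'),  # k crowd flags
    # with cfg.MODE_MASK, also 'segmentation': one list of polygons per box
}
```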
@@ -304,19 +303,19 @@ def get_train_dataflow():
     # Valid training images should have at least one fg box.
     # But this filter shall not be applied for testing.
-    num = len(imgs)
-    imgs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, imgs))
-    logger.info("Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
-        num - len(imgs), len(imgs)))
+    num = len(roidbs)
+    roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
+    logger.info("Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
+        num - len(roidbs), len(roidbs)))

-    ds = DataFromList(imgs, shuffle=True)
+    ds = DataFromList(roidbs, shuffle=True)

     aug = imgaug.AugmentorList(
         [CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
          imgaug.Flip(horiz=True)])

-    def preprocess(img):
-        fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd']
+    def preprocess(roidb):
+        fname, boxes, klass, is_crowd = roidb['file_name'], roidb['boxes'], roidb['class'], roidb['is_crowd']
         boxes = np.copy(boxes)
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
         assert im is not None, fname
@@ -331,29 +330,31 @@ def get_train_dataflow():
         boxes = point8_to_box(points)
         assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"

+        ret = {'image': im}
         # rpn anchor:
         try:
             if cfg.MODE_FPN:
                 multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
-                anchor_inputs = itertools.chain.from_iterable(multilevel_anchor_inputs)
+                for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
+                    ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
+                    ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
             else:
                 # anchor_labels, anchor_boxes
-                anchor_inputs = get_rpn_anchor_input(im, boxes, is_crowd)
-                assert len(anchor_inputs) == 2
+                ret['anchor_labels'], ret['anchor_boxes'] = get_rpn_anchor_input(im, boxes, is_crowd)

             boxes = boxes[is_crowd == 0]    # skip crowd boxes in training target
             klass = klass[is_crowd == 0]
+            ret['gt_boxes'] = boxes
+            ret['gt_labels'] = klass
             if not len(boxes):
                 raise MalformedData("No valid gt_boxes!")
         except MalformedData as e:
             log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
             return None

-        ret = [im] + list(anchor_inputs) + [boxes, klass]
-
         if cfg.MODE_MASK:
             # augmentation will modify the polys in-place
-            segmentation = copy.deepcopy(img['segmentation'])
+            segmentation = copy.deepcopy(roidb['segmentation'])
             segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
             assert len(segmentation) == len(boxes)
@@ -364,7 +365,7 @@ def get_train_dataflow():
                 polys = [aug.augment_coords(p, params) for p in polys]
                 masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
             masks = np.asarray(masks, dtype='uint8')    # values in {0, 1}
-            ret.append(masks)
+            ret['gt_masks'] = masks

             # from viz import draw_annotation, draw_mask
             # viz = draw_annotation(im, boxes, klass)
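With this commit `preprocess` returns a dict rather than a positional list. As a sanity sketch (not part of the commit), the full key set it produces, assuming the usual five FPN levels behind the `lvl{}` suffixes:

```python
# Expected keys of the dict returned by preprocess() (sketch only).
expected = {'image', 'gt_boxes', 'gt_labels'}
if cfg.MODE_FPN:
    for lvl in range(2, 7):   # assumes 5 FPN levels, hence lvl2 .. lvl6
        expected |= {'anchor_labels_lvl{}'.format(lvl),
                     'anchor_boxes_lvl{}'.format(lvl)}
else:
    expected |= {'anchor_labels', 'anchor_boxes'}
if cfg.MODE_MASK:
    expected |= {'gt_masks'}
assert set(preprocess(roidbs[0]).keys()) == expected
```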
@@ -386,13 +387,13 @@ def get_eval_dataflow(shard=0, num_shards=1):
     Args:
         shard, num_shards: to get subset of evaluation data
     """
-    imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
-    num_imgs = len(imgs)
+    roidbs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
+    num_imgs = len(roidbs)
     img_per_shard = num_imgs // num_shards
     img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)

     # no filter for training
-    ds = DataFromListOfDict(imgs[img_range[0]: img_range[1]], ['file_name', 'id'])
+    ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ['file_name', 'id'])

     def f(fname):
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
......
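A quick worked example of the shard arithmetic in `get_eval_dataflow` above, with hypothetical numbers; note that the last shard absorbs the remainder:

```python
# 10 images split across 3 shards: img_per_shard = 10 // 3 = 3.
num_imgs, num_shards = 10, 3
for shard in range(num_shards):
    img_per_shard = num_imgs // num_shards
    img_range = (shard * img_per_shard,
                 (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
    print(shard, img_range)   # 0 (0, 3) / 1 (3, 6) / 2 (6, 10)
```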
@@ -160,17 +160,14 @@ class ResNetC4Model(DetectionModel):
         return ret

     def build_graph(self, *inputs):
+        inputs = dict(zip(self.input_names, inputs))
         is_training = get_current_tower_context().is_training
-        if cfg.MODE_MASK:
-            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
-        else:
-            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
-        image = self.preprocess(image)     # 1CHW
+        image = self.preprocess(inputs['image'])     # 1CHW

         featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)

-        anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
+        anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
         anchors = anchors.narrow_to(featuremap)

         image_shape2d = tf.shape(image)[2:]     # h,w
@@ -182,6 +179,7 @@ class ResNetC4Model(DetectionModel):
             cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
             cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK)

+        gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
         if is_training:
             # sample proposal boxes in training
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
@@ -224,7 +222,7 @@ class ResNetC4Model(DetectionModel):
                     'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0)   # #fg x #cat x 14x14

                 target_masks_for_fg = crop_and_resize(
-                    tf.expand_dims(gt_masks, 1),
+                    tf.expand_dims(inputs['gt_masks'], 1),
                     fg_sampled_boxes,
                     fg_inds_wrt_gt, 14,
                     pad_border=False)   # nfg x 1x14x14
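The one-line `dict(zip(self.input_names, inputs))` is the model-side half of this commit: `build_graph` still receives one positional tensor per declared input, and re-keying them by name means optional inputs such as `gt_masks` no longer shift the unpacking, which is exactly what the deleted `if cfg.MODE_MASK` branch had to work around. A toy illustration with hypothetical placeholder tensors:

```python
# Hypothetical: names as declared by inputs(), tensors in the same order.
input_names = ['image', 'anchor_labels', 'anchor_boxes', 'gt_boxes', 'gt_labels']
tensors = [t_image, t_anchor_labels, t_anchor_boxes, t_gt_boxes, t_gt_labels]

inputs = dict(zip(input_names, tensors))
inputs['gt_boxes']   # same tensor whether or not 'gt_masks' is appended later
```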
@@ -293,18 +291,18 @@ class ResNetFPNModel(DetectionModel):
             anchors[i] = anchors[i].narrow_to(p23456[i])

     def build_graph(self, *inputs):
+        inputs = dict(zip(self.input_names, inputs))
         num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES)
         assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level
         is_training = get_current_tower_context().is_training
-        image = inputs[0]
-        input_anchors = inputs[1: 1 + 2 * num_fpn_level]
-        multilevel_anchors = [RPNAnchors(*args) for args in
-                              zip(get_all_anchors_fpn(), input_anchors[0::2], input_anchors[1::2])]
-        gt_boxes, gt_labels = inputs[11], inputs[12]
-        if cfg.MODE_MASK:
-            gt_masks = inputs[-1]

-        image = self.preprocess(image)     # 1CHW
+        all_anchors_fpn = get_all_anchors_fpn()
+        multilevel_anchors = [RPNAnchors(
+            all_anchors_fpn[i],
+            inputs['anchor_labels_lvl{}'.format(i + 2)],
+            inputs['anchor_boxes_lvl{}'.format(i + 2)]) for i in range(len(all_anchors_fpn))]

+        image = self.preprocess(inputs['image'])     # 1CHW
         image_shape2d = tf.shape(image)[2:]     # h,w

         c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
@@ -321,6 +319,7 @@ class ResNetFPNModel(DetectionModel):
             multilevel_anchors, multilevel_label_logits,
             multilevel_box_logits, image_shape2d)

+        gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
         if is_training:
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                 proposal_boxes, gt_boxes, gt_labels)
@@ -361,7 +360,7 @@ class ResNetFPNModel(DetectionModel):
                     'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY)   # #fg x #cat x 28 x 28
                 target_masks_for_fg = crop_and_resize(
-                    tf.expand_dims(gt_masks, 1),
+                    tf.expand_dims(inputs['gt_masks'], 1),
                     fg_sampled_boxes,
                     fg_inds_wrt_gt, 28,
                     pad_border=False)   # fg x 1x28x28
......
@@ -41,6 +41,10 @@ def _replace_global_by_local(kwargs):

 @contextmanager
 def override_to_local_variable(enable=True):
+    """
+    Returns:
+        a context where all variables will be created as local.
+    """
     if enable:
         def custom_getter(getter, name, *args, **kwargs):
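A usage sketch for the newly documented context manager; the variable created inside is hypothetical:

```python
# Hypothetical usage: variables created inside the block land in
# tf.GraphKeys.LOCAL_VARIABLES instead of GLOBAL_VARIABLES.
with override_to_local_variable():
    counter = tf.get_variable('counter', shape=[],
                              initializer=tf.zeros_initializer())
assert counter in tf.local_variables()
```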
@@ -55,7 +59,16 @@ def override_to_local_variable(enable=True):
 # https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L192-L218
 class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
+    """
+    Helper class to assign variables on the least loaded ps-device.
+
+    Usage:
+
+    .. code-block:: python
+
+        with tf.device(LeastLoadedDeviceSetter(...)):
+            ...
+    """
     def __init__(self, worker_device, ps_devices):
         """
         Args:
......
@@ -46,6 +46,8 @@ def _make_feeds(placeholders, datapoint):
     elif isinstance(datapoint, dict):
         ret = {p: datapoint[p.op.name] for p in placeholders}
         return ret
+    else:
+        raise TypeError("Got a datapoint of type {}!".format(type(datapoint)))


 class PlaceholderInput(InputSource):
......
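The new `else` branch turns what used to be a silent `None` into an explicit error. A sketch of the accepted datapoint forms, using hypothetical placeholders (the list form is matched positionally in the branch elided above):

```python
x = tf.placeholder(tf.float32, [None], name='x')
y = tf.placeholder(tf.float32, [None], name='y')

_make_feeds([x, y], [[1.0], [2.0]])            # list: matched by position
_make_feeds([x, y], {'x': [1.0], 'y': [2.0]})  # dict: matched by p.op.name
_make_feeds([x, y], 42)                        # anything else now raises TypeError
```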