Commit 49e04ffa authored by Yuxin Wu

[MaskRCNN] use dict as input

parent 787be08e
......@@ -4,7 +4,6 @@
import cv2
import numpy as np
import copy
import itertools
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
......@@ -282,11 +281,11 @@ def get_train_dataflow():
If MODE_MASK, gt_masks: (N, h, w)
"""
imgs = COCODetection.load_many(
roidbs = COCODetection.load_many(
cfg.DATA.BASEDIR, cfg.DATA.TRAIN, add_gt=True, add_mask=cfg.MODE_MASK)
"""
To train on your own data, change this to your loader.
Produce "imgs" as a list of dict, in the dict the following keys are needed for training:
Produce "roidbs" as a list of dict, in the dict the following keys are needed for training:
height, width: integer
file_name: str, full path to the image
boxes: numpy array of kx4 floats
......@@ -304,19 +303,19 @@ def get_train_dataflow():
# Valid training images should have at least one fg box.
# But this filter shall not be applied for testing.
num = len(imgs)
imgs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, imgs))
num = len(roidbs)
roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
logger.info("Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
num - len(imgs), len(imgs)))
num - len(roidbs), len(roidbs)))
ds = DataFromList(imgs, shuffle=True)
ds = DataFromList(roidbs, shuffle=True)
aug = imgaug.AugmentorList(
[CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
imgaug.Flip(horiz=True)])
def preprocess(img):
fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd']
def preprocess(roidb):
fname, boxes, klass, is_crowd = roidb['file_name'], roidb['boxes'], roidb['class'], roidb['is_crowd']
boxes = np.copy(boxes)
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
......@@ -331,29 +330,31 @@ def get_train_dataflow():
boxes = point8_to_box(points)
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
ret = {'image': im}
# rpn anchor:
try:
if cfg.MODE_FPN:
multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
anchor_inputs = itertools.chain.from_iterable(multilevel_anchor_inputs)
for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
else:
# anchor_labels, anchor_boxes
anchor_inputs = get_rpn_anchor_input(im, boxes, is_crowd)
assert len(anchor_inputs) == 2
ret['anchor_labels'], ret['anchor_boxes'] = get_rpn_anchor_input(im, boxes, is_crowd)
boxes = boxes[is_crowd == 0] # skip crowd boxes in training target
klass = klass[is_crowd == 0]
ret['gt_boxes'] = boxes
ret['gt_labels'] = klass
if not len(boxes):
raise MalformedData("No valid gt_boxes!")
except MalformedData as e:
log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
return None
ret = [im] + list(anchor_inputs) + [boxes, klass]
if cfg.MODE_MASK:
# augmentation will modify the polys in-place
segmentation = copy.deepcopy(img['segmentation'])
segmentation = copy.deepcopy(roidb['segmentation'])
segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
assert len(segmentation) == len(boxes)
......@@ -364,7 +365,7 @@ def get_train_dataflow():
polys = [aug.augment_coords(p, params) for p in polys]
masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
masks = np.asarray(masks, dtype='uint8') # values in {0, 1}
ret.append(masks)
ret['gt_masks'] = masks
# from viz import draw_annotation, draw_mask
# viz = draw_annotation(im, boxes, klass)
......@@ -386,13 +387,13 @@ def get_eval_dataflow(shard=0, num_shards=1):
Args:
shard, num_shards: to get subset of evaluation data
"""
imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
num_imgs = len(imgs)
roidbs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
num_imgs = len(roidbs)
img_per_shard = num_imgs // num_shards
img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
# no filter for training
ds = DataFromListOfDict(imgs[img_range[0]: img_range[1]], ['file_name', 'id'])
ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ['file_name', 'id'])
def f(fname):
im = cv2.imread(fname, cv2.IMREAD_COLOR)
......
......@@ -160,17 +160,14 @@ class ResNetC4Model(DetectionModel):
return ret
def build_graph(self, *inputs):
inputs = dict(zip(self.input_names, inputs))
is_training = get_current_tower_context().is_training
if cfg.MODE_MASK:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
else:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self.preprocess(image) # 1CHW
image = self.preprocess(inputs['image']) # 1CHW
featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)
anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
anchors = anchors.narrow_to(featuremap)
image_shape2d = tf.shape(image)[2:] # h,w
......@@ -182,6 +179,7 @@ class ResNetC4Model(DetectionModel):
cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training:
# sample proposal boxes in training
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
......@@ -224,7 +222,7 @@ class ResNetC4Model(DetectionModel):
'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0) # #fg x #cat x 14x14
target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks, 1),
tf.expand_dims(inputs['gt_masks'], 1),
fg_sampled_boxes,
fg_inds_wrt_gt, 14,
pad_border=False) # nfg x 1x14x14
......@@ -293,18 +291,18 @@ class ResNetFPNModel(DetectionModel):
anchors[i] = anchors[i].narrow_to(p23456[i])
def build_graph(self, *inputs):
inputs = dict(zip(self.input_names, inputs))
num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES)
assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level
is_training = get_current_tower_context().is_training
image = inputs[0]
input_anchors = inputs[1: 1 + 2 * num_fpn_level]
multilevel_anchors = [RPNAnchors(*args) for args in
zip(get_all_anchors_fpn(), input_anchors[0::2], input_anchors[1::2])]
gt_boxes, gt_labels = inputs[11], inputs[12]
if cfg.MODE_MASK:
gt_masks = inputs[-1]
image = self.preprocess(image) # 1CHW
all_anchors_fpn = get_all_anchors_fpn()
multilevel_anchors = [RPNAnchors(
all_anchors_fpn[i],
inputs['anchor_labels_lvl{}'.format(i + 2)],
inputs['anchor_boxes_lvl{}'.format(i + 2)]) for i in range(len(all_anchors_fpn))]
image = self.preprocess(inputs['image']) # 1CHW
image_shape2d = tf.shape(image)[2:] # h,w
c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
......@@ -321,6 +319,7 @@ class ResNetFPNModel(DetectionModel):
multilevel_anchors, multilevel_label_logits,
multilevel_box_logits, image_shape2d)
gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
if is_training:
rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels)
......@@ -361,7 +360,7 @@ class ResNetFPNModel(DetectionModel):
'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY) # #fg x #cat x 28 x 28
target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks, 1),
tf.expand_dims(inputs['gt_masks'], 1),
fg_sampled_boxes,
fg_inds_wrt_gt, 28,
pad_border=False) # fg x 1x28x28
......
......@@ -41,6 +41,10 @@ def _replace_global_by_local(kwargs):
@contextmanager
def override_to_local_variable(enable=True):
"""
Returns:
a context where all variables will be created as local.
"""
if enable:
def custom_getter(getter, name, *args, **kwargs):
......@@ -55,7 +59,16 @@ def override_to_local_variable(enable=True):
# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L192-L218
class LeastLoadedDeviceSetter(object):
""" Helper class to assign variables on the least loaded ps-device."""
"""
Helper class to assign variables on the least loaded ps-device.
Usage:
.. code-block:: python
with tf.device(LeastLoadedDeviceSetter(...)):
...
"""
def __init__(self, worker_device, ps_devices):
"""
Args:
......
......@@ -46,6 +46,8 @@ def _make_feeds(placeholders, datapoint):
elif isinstance(datapoint, dict):
ret = {p: datapoint[p.op.name] for p in placeholders}
return ret
else:
raise TypeError("Got a datapoint of type {}!".format(type(datapoint)))
class PlaceholderInput(InputSource):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment