Commit 8f10b0f8 authored by Yuxin Wu

[MaskRCNN] use "spawn" for safer multiprocessing

parent 2902bfbe
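
The change set below switches the data-loading and monitoring subprocesses from the platform default start method ("fork" on Linux) to "spawn". A minimal sketch of the difference, illustrative and not part of the commit: a forked child inherits the parent's entire address space, including TensorFlow and CUDA state, which can deadlock or bloat memory; a spawned child starts a clean interpreter and receives only pickled arguments, so every resource it needs must be passed explicitly.

# Minimal sketch (illustrative, not from the commit): "spawn" starts a fresh
# interpreter per child, so workers receive only what is pickled to them.
import multiprocessing as mp

def worker(num_workers):
    # no inherited globals or CUDA contexts here -- only the pickled argument
    print("child sees num_workers =", num_workers)

if __name__ == '__main__':          # required: "spawn" re-imports this module
    mp.set_start_method('spawn')    # must run once, before creating any process
    p = mp.Process(target=worker, args=(10,))
    p.start()
    p.join()
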
@@ -93,7 +93,10 @@ _C.DATA.NUM_CATEGORY = 80  # without the background class (e.g., 80 for COCO)
 _C.DATA.CLASS_NAMES = []  # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
 # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
 _C.DATA.ABSOLUTE_COORD = True
-_C.DATA.NUM_WORKERS = 5  # number of data loading workers. set to 0 to disable parallel data loading
+# Number of data loading workers.
+# In case of horovod training, this is the number of workers per-GPU (so you may want to use a smaller number).
+# Set to 0 to disable parallel data loading
+_C.DATA.NUM_WORKERS = 10

 # backbone ----------------------
 _C.BACKBONE.WEIGHTS = ''  # /path/to/weights.npz
...
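
A hedged sketch of what the per-GPU comment above implies for the total number of loader processes; the helper and numbers are illustrative, not from the repo:

# Hypothetical helper illustrating the comment above: under horovod each GPU
# runs its own training process with its own dataflow, so NUM_WORKERS loader
# processes are created per GPU rather than per machine.
def total_loader_processes(num_workers, trainer, num_gpus):
    if trainer == 'horovod':
        return num_workers * num_gpus   # e.g. 10 workers x 8 GPUs = 80 processes
    return num_workers                  # one shared dataflow otherwise

assert total_loader_processes(10, 'horovod', 8) == 80
assert total_loader_processes(0, 'replicated', 8) == 0  # 0 disables parallel loading
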
@@ -54,33 +54,30 @@ def print_class_histogram(roidbs):
 @memoized
-def get_all_anchors(stride=None, sizes=None):
+def get_all_anchors(*, stride, sizes, ratios, max_size):
     """
     Get all anchors in the largest possible image, shifted, floatbox
     Args:
         stride (int): the stride of anchors.
         sizes (tuple[int]): the sizes (sqrt area) of anchors
+        ratios (tuple[int]): the aspect ratios of anchors
+        max_size (int): maximum size of input image

     Returns:
         anchors: SxSxNUM_ANCHORx4, where S == ceil(MAX_SIZE/STRIDE), floatbox
         The layout in the NUM_ANCHOR dim is NUM_RATIO x NUM_SIZE.
     """
-    if stride is None:
-        stride = cfg.RPN.ANCHOR_STRIDE
-    if sizes is None:
-        sizes = cfg.RPN.ANCHOR_SIZES
     # Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
     # are centered on stride / 2, have (approximate) sqrt areas of the specified
     # sizes, and aspect ratios as given.
     cell_anchors = generate_anchors(
         stride,
         scales=np.array(sizes, dtype=np.float) / stride,
-        ratios=np.array(cfg.RPN.ANCHOR_RATIOS, dtype=np.float))
+        ratios=np.array(ratios, dtype=np.float))
     # anchors are intbox here.
     # anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
-    max_size = cfg.PREPROC.MAX_SIZE
     field_size = int(np.ceil(max_size / stride))
     shifts = np.arange(0, field_size) * stride
     shift_x, shift_y = np.meshgrid(shifts, shifts)
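
Because get_all_anchors is @memoized and now takes keyword-only arguments instead of reading cfg globals, every input that affects the result is part of the cache key, and the function can run in a freshly spawned worker that never initialized the global config. A minimal sketch of the idea; functools.lru_cache stands in for tensorpack's @memoized, and the values are the repo's plausible defaults, used here only as an illustration:

# Hedged sketch: keyword-only, hashable args make memoization and "spawn" safe.
import functools

@functools.lru_cache(maxsize=None)
def get_all_anchors_sketch(*, stride, sizes, ratios, max_size):
    # all inputs are explicit and hashable (ints, tuples), so the cache key is
    # complete and no hidden global config is consulted
    return (stride, sizes, ratios, max_size)

# callers pass tuples, not lists, so the arguments stay hashable:
a = get_all_anchors_sketch(stride=16, sizes=(32, 64, 128, 256, 512),
                           ratios=(0.5, 1., 2.), max_size=1333)
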
@@ -104,206 +101,33 @@ def get_all_anchors(stride=None, sizes=None):
 @memoized
-def get_all_anchors_fpn(strides=None, sizes=None):
+def get_all_anchors_fpn(*, strides, sizes, ratios, max_size):
     """
     Returns:
         [anchors]: each anchors is a SxSx NUM_ANCHOR_RATIOS x4 array.
     """
-    if strides is None:
-        strides = cfg.FPN.ANCHOR_STRIDES
-    if sizes is None:
-        sizes = cfg.RPN.ANCHOR_SIZES
     assert len(strides) == len(sizes)
     foas = []
     for stride, size in zip(strides, sizes):
-        foa = get_all_anchors(stride=stride, sizes=(size,))
+        foa = get_all_anchors(stride=stride, sizes=(size,), ratios=ratios, max_size=max_size)
         foas.append(foa)
     return foas

-def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
-    """
-    Label each anchor as fg/bg/ignore.
-    Args:
-        anchors: Ax4 float
-        gt_boxes: Bx4 float, non-crowd
-        crowd_boxes: Cx4 float
-
-    Returns:
-        anchor_labels: (A,) int. Each element is {-1, 0, 1}
-        anchor_boxes: Ax4. Contains the target gt_box for each anchor when the anchor is fg.
-    """
-    # This function will modify labels and return the filtered inds
-    def filter_box_label(labels, value, max_num):
-        curr_inds = np.where(labels == value)[0]
-        if len(curr_inds) > max_num:
-            disable_inds = np.random.choice(
-                curr_inds, size=(len(curr_inds) - max_num),
-                replace=False)
-            labels[disable_inds] = -1   # ignore them
-            curr_inds = np.where(labels == value)[0]
-        return curr_inds
-
-    NA, NB = len(anchors), len(gt_boxes)
-    assert NB > 0  # empty images should have been filtered already
-    box_ious = np_iou(anchors, gt_boxes)  # NA x NB
-    ious_argmax_per_anchor = box_ious.argmax(axis=1)  # NA,
-    ious_max_per_anchor = box_ious.max(axis=1)
-    ious_max_per_gt = np.amax(box_ious, axis=0, keepdims=True)  # 1xNB
-    # for each gt, find all those anchors (including ties) that have the max iou with it
-    anchors_with_max_iou_per_gt = np.where(box_ious == ious_max_per_gt)[0]
-    # Setting NA labels: 1--fg 0--bg -1--ignore
-    anchor_labels = -np.ones((NA,), dtype='int32')   # NA,
-    # the order of setting neg/pos labels matters
-    anchor_labels[anchors_with_max_iou_per_gt] = 1
-    anchor_labels[ious_max_per_anchor >= cfg.RPN.POSITIVE_ANCHOR_THRESH] = 1
-    anchor_labels[ious_max_per_anchor < cfg.RPN.NEGATIVE_ANCHOR_THRESH] = 0
-    # label all non-ignore candidate boxes which overlap crowd as ignore
-    if crowd_boxes.size > 0:
-        cand_inds = np.where(anchor_labels >= 0)[0]
-        cand_anchors = anchors[cand_inds]
-        ioas = np_ioa(crowd_boxes, cand_anchors)
-        overlap_with_crowd = cand_inds[ioas.max(axis=0) > cfg.RPN.CROWD_OVERLAP_THRESH]
-        anchor_labels[overlap_with_crowd] = -1
-    # Subsample fg labels: ignore some fg if fg is too many
-    target_num_fg = int(cfg.RPN.BATCH_PER_IM * cfg.RPN.FG_RATIO)
-    fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
-    # Keep an image even if there are no foreground anchors
-    # if len(fg_inds) == 0:
-    #     raise MalformedData("No valid foreground for RPN!")
-    # Subsample bg labels. num_bg is not allowed to be too many
-    old_num_bg = np.sum(anchor_labels == 0)
-    if old_num_bg == 0:
-        # No valid bg in this image, skip.
-        raise MalformedData("No valid background for RPN!")
-    target_num_bg = cfg.RPN.BATCH_PER_IM - len(fg_inds)
-    filter_box_label(anchor_labels, 0, target_num_bg)   # ignore return values
-    # Set anchor boxes: the best gt_box for each fg anchor
-    anchor_boxes = np.zeros((NA, 4), dtype='float32')
-    fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
-    anchor_boxes[fg_inds, :] = fg_boxes
-    # assert len(fg_inds) + np.sum(anchor_labels == 0) == cfg.RPN.BATCH_PER_IM
-    return anchor_labels, anchor_boxes
-
-
-def get_rpn_anchor_input(im, boxes, is_crowd):
-    """
-    Args:
-        im: an image
-        boxes: nx4, floatbox, gt. shouldn't be changed
-        is_crowd: n,
-
-    Returns:
-        The anchor labels and target boxes for each pixel in the featuremap.
-        fm_labels: fHxfWxNA
-        fm_boxes: fHxfWxNAx4
-        NA will be NUM_ANCHOR_SIZES x NUM_ANCHOR_RATIOS
-    """
-    boxes = boxes.copy()
-    all_anchors = np.copy(get_all_anchors())
-    # fHxfWxAx4 -> (-1, 4)
-    featuremap_anchors_flatten = all_anchors.reshape((-1, 4))
-    # only use anchors inside the image
-    inside_ind, inside_anchors = filter_boxes_inside_shape(featuremap_anchors_flatten, im.shape[:2])
-    # obtain anchor labels and their corresponding gt boxes
-    anchor_labels, anchor_gt_boxes = get_anchor_labels(inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
-    # Fill them back to original size: fHxfWx1, fHxfWx4
-    anchorH, anchorW = all_anchors.shape[:2]
-    featuremap_labels = -np.ones((anchorH * anchorW * cfg.RPN.NUM_ANCHOR, ), dtype='int32')
-    featuremap_labels[inside_ind] = anchor_labels
-    featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, cfg.RPN.NUM_ANCHOR))
-    featuremap_boxes = np.zeros((anchorH * anchorW * cfg.RPN.NUM_ANCHOR, 4), dtype='float32')
-    featuremap_boxes[inside_ind, :] = anchor_gt_boxes
-    featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, cfg.RPN.NUM_ANCHOR, 4))
-    return featuremap_labels, featuremap_boxes
-
-
-def get_multilevel_rpn_anchor_input(im, boxes, is_crowd):
-    """
-    Args:
-        im: an image
-        boxes: nx4, floatbox, gt. shouldn't be changed
-        is_crowd: n,
-
-    Returns:
-        [(fm_labels, fm_boxes)]: Returns a tuple for each FPN level.
-        Each tuple contains the anchor labels and target boxes for each pixel in the featuremap.
-
-        fm_labels: fHxfWx NUM_ANCHOR_RATIOS
-        fm_boxes: fHxfWx NUM_ANCHOR_RATIOS x4
-    """
-    boxes = boxes.copy()
-    anchors_per_level = get_all_anchors_fpn()
-    flatten_anchors_per_level = [k.reshape((-1, 4)) for k in anchors_per_level]
-    all_anchors_flatten = np.concatenate(flatten_anchors_per_level, axis=0)
-    inside_ind, inside_anchors = filter_boxes_inside_shape(all_anchors_flatten, im.shape[:2])
-    anchor_labels, anchor_gt_boxes = get_anchor_labels(inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
-    # map back to all_anchors, then split to each level
-    num_all_anchors = all_anchors_flatten.shape[0]
-    all_labels = -np.ones((num_all_anchors, ), dtype='int32')
-    all_labels[inside_ind] = anchor_labels
-    all_boxes = np.zeros((num_all_anchors, 4), dtype='float32')
-    all_boxes[inside_ind] = anchor_gt_boxes
-    start = 0
-    multilevel_inputs = []
-    for level_anchor in anchors_per_level:
-        assert level_anchor.shape[2] == len(cfg.RPN.ANCHOR_RATIOS)
-        anchor_shape = level_anchor.shape[:3]   # fHxfWxNUM_ANCHOR_RATIOS
-        num_anchor_this_level = np.prod(anchor_shape)
-        end = start + num_anchor_this_level
-        multilevel_inputs.append(
-            (all_labels[start: end].reshape(anchor_shape),
-             all_boxes[start: end, :].reshape(anchor_shape + (4,))
-             ))
-        start = end
-    assert end == num_all_anchors, "{} != {}".format(end, num_all_anchors)
-    return multilevel_inputs
-
-
-def get_train_dataflow():
-    """
-    Return a training dataflow. Each datapoint consists of the following:
-
-    An image: (h, w, 3),
-    1 or more pairs of (anchor_labels, anchor_boxes):
-    anchor_labels: (h', w', NA)
-    anchor_boxes: (h', w', NA, 4)
-    gt_boxes: (N, 4)
-    gt_labels: (N,)
-    If MODE_MASK, gt_masks: (N, h, w)
-    """
-    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
-    print_class_histogram(roidbs)
-
-    # Valid training images should have at least one fg box.
-    # But this filter shall not be applied for testing.
-    num = len(roidbs)
-    roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
-    logger.info("Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
-        num - len(roidbs), len(roidbs)))
-
-    ds = DataFromList(roidbs, shuffle=True)
-
-    aug = imgaug.AugmentorList(
-        [CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
-         imgaug.Flip(horiz=True)])
-
-    def preprocess(roidb):
+
+class TrainingDataPreprocessor:
+    """
+    The mapper to preprocess the input data for training.
+
+    Since the mapping may run in other processes, we write a new class and
+    explicitly pass cfg to it, in the spirit of "explicitly pass resources to subprocess".
+    """
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.aug = imgaug.AugmentorList(
+            [CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
+             imgaug.Flip(horiz=True)])
+
+    def __call__(self, roidb):
         fname, boxes, klass, is_crowd = roidb['file_name'], roidb['boxes'], roidb['class'], roidb['is_crowd']
         boxes = np.copy(boxes)
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
@@ -313,27 +137,27 @@ def get_train_dataflow():
         # assume floatbox as input
         assert boxes.dtype == np.float32, "Loader has to return floating point boxes!"

-        if not cfg.DATA.ABSOLUTE_COORD:
+        if not self.cfg.DATA.ABSOLUTE_COORD:
             boxes[:, 0::2] *= width
             boxes[:, 1::2] *= height

         # augmentation:
-        im, params = aug.augment_return_params(im)
+        im, params = self.aug.augment_return_params(im)
         points = box_to_point8(boxes)
-        points = aug.augment_coords(points, params)
+        points = self.aug.augment_coords(points, params)
         boxes = point8_to_box(points)
         assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"

         ret = {'image': im}
         # Add rpn data to dataflow:
         try:
-            if cfg.MODE_FPN:
-                multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
+            if self.cfg.MODE_FPN:
+                multilevel_anchor_inputs = self.get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
                 for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
                     ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
                     ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
             else:
-                ret['anchor_labels'], ret['anchor_boxes'] = get_rpn_anchor_input(im, boxes, is_crowd)
+                ret['anchor_labels'], ret['anchor_boxes'] = self.get_rpn_anchor_input(im, boxes, is_crowd)

             boxes = boxes[is_crowd == 0]   # skip crowd boxes in training target
             klass = klass[is_crowd == 0]
@@ -345,7 +169,7 @@ def get_train_dataflow():
             log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
             return None

-        if cfg.MODE_MASK:
+        if self.cfg.MODE_MASK:
             # augmentation will modify the polys in-place
             segmentation = copy.deepcopy(roidb['segmentation'])
             segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
@@ -356,9 +180,9 @@ def get_train_dataflow():
             masks = []
             width_height = np.asarray([width, height], dtype=np.float32)
             for polys in segmentation:
-                if not cfg.DATA.ABSOLUTE_COORD:
+                if not self.cfg.DATA.ABSOLUTE_COORD:
                     polys = [p * width_height for p in polys]
-                polys = [aug.augment_coords(p, params) for p in polys]
+                polys = [self.aug.augment_coords(p, params) for p in polys]
                 masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
             masks = np.asarray(masks, dtype='uint8')   # values in {0, 1}
             ret['gt_masks'] = masks
@@ -370,6 +194,195 @@ def get_train_dataflow():
             #     tpviz.interactive_imshow(viz)
         return ret

+    def get_rpn_anchor_input(self, im, boxes, is_crowd):
+        """
+        Args:
+            im: an image
+            boxes: nx4, floatbox, gt. shouldn't be changed
+            is_crowd: n,
+
+        Returns:
+            The anchor labels and target boxes for each pixel in the featuremap.
+            fm_labels: fHxfWxNA
+            fm_boxes: fHxfWxNAx4
+            NA will be NUM_ANCHOR_SIZES x NUM_ANCHOR_RATIOS
+        """
+        boxes = boxes.copy()
+        all_anchors = np.copy(get_all_anchors(
+            stride=self.cfg.RPN.ANCHOR_STRIDE,
+            sizes=self.cfg.RPN.ANCHOR_SIZES,
+            ratios=self.cfg.RPN.ANCHOR_RATIOS,
+            max_size=self.cfg.PREPROC.MAX_SIZE))
+        # fHxfWxAx4 -> (-1, 4)
+        featuremap_anchors_flatten = all_anchors.reshape((-1, 4))
+
+        # only use anchors inside the image
+        inside_ind, inside_anchors = filter_boxes_inside_shape(featuremap_anchors_flatten, im.shape[:2])
+        # obtain anchor labels and their corresponding gt boxes
+        anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
+            inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
+
+        # Fill them back to original size: fHxfWx1, fHxfWx4
+        num_anchor = self.cfg.RPN.NUM_ANCHOR
+        anchorH, anchorW = all_anchors.shape[:2]
+        featuremap_labels = -np.ones((anchorH * anchorW * num_anchor, ), dtype='int32')
+        featuremap_labels[inside_ind] = anchor_labels
+        featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, num_anchor))
+        featuremap_boxes = np.zeros((anchorH * anchorW * num_anchor, 4), dtype='float32')
+        featuremap_boxes[inside_ind, :] = anchor_gt_boxes
+        featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, num_anchor, 4))
+        return featuremap_labels, featuremap_boxes
+
+    def get_multilevel_rpn_anchor_input(self, im, boxes, is_crowd):
+        """
+        Args:
+            im: an image
+            boxes: nx4, floatbox, gt. shouldn't be changed
+            is_crowd: n,
+
+        Returns:
+            [(fm_labels, fm_boxes)]: Returns a tuple for each FPN level.
+            Each tuple contains the anchor labels and target boxes for each pixel in the featuremap.
+
+            fm_labels: fHxfWx NUM_ANCHOR_RATIOS
+            fm_boxes: fHxfWx NUM_ANCHOR_RATIOS x4
+        """
+        boxes = boxes.copy()
+        anchors_per_level = get_all_anchors_fpn(
+            strides=self.cfg.FPN.ANCHOR_STRIDES,
+            sizes=self.cfg.RPN.ANCHOR_SIZES,
+            ratios=self.cfg.RPN.ANCHOR_RATIOS,
+            max_size=self.cfg.PREPROC.MAX_SIZE)
+        flatten_anchors_per_level = [k.reshape((-1, 4)) for k in anchors_per_level]
+        all_anchors_flatten = np.concatenate(flatten_anchors_per_level, axis=0)
+
+        inside_ind, inside_anchors = filter_boxes_inside_shape(all_anchors_flatten, im.shape[:2])
+        anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
+            inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
+
+        # map back to all_anchors, then split to each level
+        num_all_anchors = all_anchors_flatten.shape[0]
+        all_labels = -np.ones((num_all_anchors, ), dtype='int32')
+        all_labels[inside_ind] = anchor_labels
+        all_boxes = np.zeros((num_all_anchors, 4), dtype='float32')
+        all_boxes[inside_ind] = anchor_gt_boxes
+
+        start = 0
+        multilevel_inputs = []
+        for level_anchor in anchors_per_level:
+            assert level_anchor.shape[2] == len(self.cfg.RPN.ANCHOR_RATIOS)
+            anchor_shape = level_anchor.shape[:3]   # fHxfWxNUM_ANCHOR_RATIOS
+            num_anchor_this_level = np.prod(anchor_shape)
+            end = start + num_anchor_this_level
+            multilevel_inputs.append(
+                (all_labels[start: end].reshape(anchor_shape),
+                 all_boxes[start: end, :].reshape(anchor_shape + (4,))
+                 ))
+            start = end
+        assert end == num_all_anchors, "{} != {}".format(end, num_all_anchors)
+        return multilevel_inputs
+
+    def get_anchor_labels(self, anchors, gt_boxes, crowd_boxes):
+        """
+        Label each anchor as fg/bg/ignore.
+
+        Args:
+            anchors: Ax4 float
+            gt_boxes: Bx4 float, non-crowd
+            crowd_boxes: Cx4 float
+
+        Returns:
+            anchor_labels: (A,) int. Each element is {-1, 0, 1}
+            anchor_boxes: Ax4. Contains the target gt_box for each anchor when the anchor is fg.
+        """
+        # This function will modify labels and return the filtered inds
+        def filter_box_label(labels, value, max_num):
+            curr_inds = np.where(labels == value)[0]
+            if len(curr_inds) > max_num:
+                disable_inds = np.random.choice(
+                    curr_inds, size=(len(curr_inds) - max_num),
+                    replace=False)
+                labels[disable_inds] = -1   # ignore them
+                curr_inds = np.where(labels == value)[0]
+            return curr_inds
+
+        NA, NB = len(anchors), len(gt_boxes)
+        assert NB > 0  # empty images should have been filtered already
+
+        box_ious = np_iou(anchors, gt_boxes)  # NA x NB
+        ious_argmax_per_anchor = box_ious.argmax(axis=1)  # NA,
+        ious_max_per_anchor = box_ious.max(axis=1)
+        ious_max_per_gt = np.amax(box_ious, axis=0, keepdims=True)  # 1xNB
+        # for each gt, find all those anchors (including ties) that have the max iou with it
+        anchors_with_max_iou_per_gt = np.where(box_ious == ious_max_per_gt)[0]
+
+        # Setting NA labels: 1--fg 0--bg -1--ignore
+        anchor_labels = -np.ones((NA,), dtype='int32')   # NA,
+        # the order of setting neg/pos labels matters
+        anchor_labels[anchors_with_max_iou_per_gt] = 1
+        anchor_labels[ious_max_per_anchor >= self.cfg.RPN.POSITIVE_ANCHOR_THRESH] = 1
+        anchor_labels[ious_max_per_anchor < self.cfg.RPN.NEGATIVE_ANCHOR_THRESH] = 0
+
+        # label all non-ignore candidate boxes which overlap crowd as ignore
+        if crowd_boxes.size > 0:
+            cand_inds = np.where(anchor_labels >= 0)[0]
+            cand_anchors = anchors[cand_inds]
+            ioas = np_ioa(crowd_boxes, cand_anchors)
+            overlap_with_crowd = cand_inds[ioas.max(axis=0) > self.cfg.RPN.CROWD_OVERLAP_THRESH]
+            anchor_labels[overlap_with_crowd] = -1
+
+        # Subsample fg labels: ignore some fg if fg is too many
+        target_num_fg = int(self.cfg.RPN.BATCH_PER_IM * self.cfg.RPN.FG_RATIO)
+        fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
+        # Keep an image even if there are no foreground anchors
+        # if len(fg_inds) == 0:
+        #     raise MalformedData("No valid foreground for RPN!")
+
+        # Subsample bg labels. num_bg is not allowed to be too many
+        old_num_bg = np.sum(anchor_labels == 0)
+        if old_num_bg == 0:
+            # No valid bg in this image, skip.
+            raise MalformedData("No valid background for RPN!")
+        target_num_bg = self.cfg.RPN.BATCH_PER_IM - len(fg_inds)
+        filter_box_label(anchor_labels, 0, target_num_bg)   # ignore return values
+
+        # Set anchor boxes: the best gt_box for each fg anchor
+        anchor_boxes = np.zeros((NA, 4), dtype='float32')
+        fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
+        anchor_boxes[fg_inds, :] = fg_boxes
+        # assert len(fg_inds) + np.sum(anchor_labels == 0) == self.cfg.RPN.BATCH_PER_IM
+        return anchor_labels, anchor_boxes
+
+
+def get_train_dataflow():
+    """
+    Return a training dataflow. Each datapoint consists of the following:
+
+    An image: (h, w, 3),
+    1 or more pairs of (anchor_labels, anchor_boxes):
+    anchor_labels: (h', w', NA)
+    anchor_boxes: (h', w', NA, 4)
+    gt_boxes: (N, 4)
+    gt_labels: (N,)
+    If MODE_MASK, gt_masks: (N, h, w)
+    """
+    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
+    print_class_histogram(roidbs)
+
+    # Valid training images should have at least one fg box.
+    # But this filter shall not be applied for testing.
+    num = len(roidbs)
+    roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
+    logger.info("Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
+        num - len(roidbs), len(roidbs)))
+
+    ds = DataFromList(roidbs, shuffle=True)
+    preprocess = TrainingDataPreprocessor(cfg)
+
     if cfg.DATA.NUM_WORKERS > 0:
         if cfg.TRAINER == 'horovod':
             buffer_size = cfg.DATA.NUM_WORKERS * 10   # one dataflow for each process, therefore don't need large buffer
...
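
The class-based mapper above replaces the old nested preprocess closure because a closure cannot be pickled to a spawned worker, while a module-level class instance carrying cfg in its attributes can. A hedged sketch of how such a mapper might be wired into a parallel dataflow; MultiProcessMapData is tensorpack's parallel mapper, but the buffer size and the serial fallback shown here are assumptions, not the commit's exact code:

# Hedged wiring sketch (buffer_size and the MapData fallback are assumptions):
from tensorpack.dataflow import DataFromList, MapData, MultiProcessMapData

ds = DataFromList(roidbs, shuffle=True)
preprocess = TrainingDataPreprocessor(cfg)  # picklable: cfg travels as an attribute
if cfg.DATA.NUM_WORKERS > 0:
    # each spawned worker unpickles `preprocess`; nothing is inherited implicitly
    ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=200)
else:
    ds = MapData(ds, preprocess)  # NUM_WORKERS == 0: run the mapper in-process
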
@@ -101,7 +101,11 @@ class ResNetC4Model(GeneralizedRCNN):
     def rpn(self, image, features, inputs):
         featuremap = features[0]
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)
-        anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
+        anchors = RPNAnchors(
+            get_all_anchors(
+                stride=cfg.RPN.ANCHOR_STRIDE, sizes=cfg.RPN.ANCHOR_SIZES,
+                ratios=cfg.RPN.ANCHOR_RATIOS, max_size=cfg.PREPROC.MAX_SIZE),
+            inputs['anchor_labels'], inputs['anchor_boxes'])
         anchors = anchors.narrow_to(featuremap)

         image_shape2d = tf.shape(image)[2:]     # h,w
@@ -216,7 +220,11 @@ class ResNetFPNModel(GeneralizedRCNN):
         assert len(cfg.RPN.ANCHOR_SIZES) == len(cfg.FPN.ANCHOR_STRIDES)

         image_shape2d = tf.shape(image)[2:]     # h,w
-        all_anchors_fpn = get_all_anchors_fpn()
+        all_anchors_fpn = get_all_anchors_fpn(
+            strides=cfg.FPN.ANCHOR_STRIDES,
+            sizes=cfg.RPN.ANCHOR_SIZES,
+            ratios=cfg.RPN.ANCHOR_RATIOS,
+            max_size=cfg.PREPROC.MAX_SIZE)
         multilevel_anchors = [RPNAnchors(
             all_anchors_fpn[i],
             inputs['anchor_labels_lvl{}'.format(i + 2)],
...
@@ -25,6 +25,8 @@ except ImportError:

 if __name__ == '__main__':
+    import multiprocessing as mp
+    mp.set_start_method('spawn')    # safer behavior & memory saving
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', help='load a model to start training from. Can overwrite BACKBONE.WEIGHTS')
     parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn')
...
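
A note on the placement above, sketched from CPython's documented spawn semantics: every spawned child re-imports the main module, so set_start_method must sit under the __main__ guard; module-level code runs again in each child, and an unguarded call would raise RuntimeError when the children try to set the start method a second time.

# Sketch of why set_start_method('spawn') belongs under the __main__ guard:
# spawned children re-import this module; module-level code runs in every
# child, but the guarded block runs only in the parent.
import multiprocessing as mp

def square(x):
    return x * x

if __name__ == '__main__':
    mp.set_start_method('spawn')           # parent only; children skip this block
    with mp.Pool(processes=2) as pool:
        print(pool.map(square, range(4)))  # [0, 1, 4, 9]
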
@@ -59,7 +59,7 @@ class GPUUtilizationTracker(Callback):
         self._stop_evt = mp.Event()
         self._queue = mp.Queue()
         self._proc = mp.Process(target=self.worker, args=(
-            self._evt, self._queue, self._stop_evt))
+            self._evt, self._queue, self._stop_evt, self._devices))
         ensure_proc_terminate(self._proc)
         start_proc_mask_signal(self._proc)

@@ -96,9 +96,14 @@ class GPUUtilizationTracker(Callback):
         self._evt.set()
         self._proc.terminate()

-    def worker(self, evt, rst_queue, stop_evt):
+    @staticmethod
+    def worker(evt, rst_queue, stop_evt, devices):
+        """
+        Args:
+            devices (list[int])
+        """
         with NVMLContext() as ctx:
-            devices = [ctx.device(i) for i in self._devices]
+            devices = [ctx.device(i) for i in devices]
             while True:
                 try:
                     evt.wait()  # start epoch

@@ -106,7 +111,7 @@ class GPUUtilizationTracker(Callback):
                     if stop_evt.is_set():   # or on exit
                         return

-                    stats = np.zeros((len(self._devices),), dtype='f4')
+                    stats = np.zeros((len(devices),), dtype='f4')
                     cnt = 0
                     while True:
                         time.sleep(1)
...
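
The worker above becomes a @staticmethod taking devices explicitly because, under spawn, mp.Process pickles its target: a bound method would drag the whole Callback object (and whatever TF state it references) across the process boundary, while a staticmethod plus plain arguments keeps the child lightweight and picklable. A stripped-down sketch of the pattern; the class and fields here are hypothetical stand-ins:

# Stripped-down sketch of the pattern above (names hypothetical):
import multiprocessing as mp

class UtilizationTracker:
    def __init__(self, devices):
        self._devices = devices          # e.g. [0, 1]

    @staticmethod
    def worker(devices):
        # runs in the spawned child with only the pickled argument;
        # no `self`, so nothing unpicklable crosses the process boundary
        print("tracking devices:", devices)

    def start(self):
        # self.worker resolves to the plain function, which pickles by name
        proc = mp.Process(target=self.worker, args=(self._devices,))
        proc.start()
        return proc

if __name__ == '__main__':
    mp.set_start_method('spawn')
    UtilizationTracker([0, 1]).start().join()
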