Commit 7ce02248 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] update on rpn data

parent 83695c0b
...@@ -73,7 +73,7 @@ def resnet_bottleneck(l, ch_out, stride): ...@@ -73,7 +73,7 @@ def resnet_bottleneck(l, ch_out, stride):
l, shortcut = l, l l, shortcut = l, l
l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU) l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
if stride == 2: if stride == 2:
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]]) l = tf.pad(l, [[0, 0], [0, 0], [1, 0], [1, 0]])
l = Conv2D('conv2', l, ch_out, 3, strides=2, activation=BNReLU, padding='VALID') l = Conv2D('conv2', l, ch_out, 3, strides=2, activation=BNReLU, padding='VALID')
else: else:
l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU) l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
...@@ -95,9 +95,9 @@ def resnet_group(name, l, block_func, features, count, stride): ...@@ -95,9 +95,9 @@ def resnet_group(name, l, block_func, features, count, stride):
def resnet_c4_backbone(image, num_blocks, freeze_c2=True): def resnet_c4_backbone(image, num_blocks, freeze_c2=True):
assert len(num_blocks) == 3 assert len(num_blocks) == 3
with resnet_argscope(): with resnet_argscope():
l = tf.pad(image, [[0, 0], [0, 0], [2, 3], [2, 3]]) l = tf.pad(image, [[0, 0], [0, 0], [3, 2], [3, 2]])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID') l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]]) l = tf.pad(l, [[0, 0], [0, 0], [1, 0], [1, 0]])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID') l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1) c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
# TODO replace var by const to enable optimization # TODO replace var by const to enable optimization
...@@ -125,10 +125,10 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True): ...@@ -125,10 +125,10 @@ def resnet_fpn_backbone(image, num_blocks, freeze_c2=True):
with resnet_argscope(): with resnet_argscope():
chan = image.shape[1] chan = image.shape[1]
l = tf.pad(image, tf.stack( l = tf.pad(image, tf.stack(
[[0, 0], [0, 0], [2, 3 + pad_shape2d[0]], [2, 3 + pad_shape2d[1]]])) [[0, 0], [0, 0], [3, 2 + pad_shape2d[0]], [3, 2 + pad_shape2d[1]]]))
l.set_shape([None, chan, None, None]) l.set_shape([None, chan, None, None])
l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID') l = Conv2D('conv0', l, 64, 7, strides=2, activation=BNReLU, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]]) l = tf.pad(l, [[0, 0], [0, 0], [1, 0], [1, 0]])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID') l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1) c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
if freeze_c2: if freeze_c2:
......
...@@ -162,7 +162,9 @@ class COCODetection(object): ...@@ -162,7 +162,9 @@ class COCODetection(object):
img['class'] = cls # n, always >0 img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n, img['is_crowd'] = is_crowd # n,
if add_mask: if add_mask:
img['segmentation'] = [obj['segmentation'] for obj in valid_objs] # also required to be float32
img['segmentation'] = [
obj['segmentation'].astype('float32') for obj in valid_objs]
def print_class_histogram(self, imgs): def print_class_histogram(self, imgs):
nr_class = len(COCOMeta.class_names) nr_class = len(COCOMeta.class_names)
......
...@@ -134,20 +134,21 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes): ...@@ -134,20 +134,21 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
anchor_labels[ious_max_per_anchor >= config.POSITIVE_ANCHOR_THRES] = 1 anchor_labels[ious_max_per_anchor >= config.POSITIVE_ANCHOR_THRES] = 1
anchor_labels[ious_max_per_anchor < config.NEGATIVE_ANCHOR_THRES] = 0 anchor_labels[ious_max_per_anchor < config.NEGATIVE_ANCHOR_THRES] = 0
# First label all non-ignore candidate boxes which overlap crowd as ignore # We can label all non-ignore candidate boxes which overlap crowd as ignore
if crowd_boxes.size > 0: # But detectron did not do this.
cand_inds = np.where(anchor_labels >= 0)[0] # if crowd_boxes.size > 0:
cand_anchors = anchors[cand_inds] # cand_inds = np.where(anchor_labels >= 0)[0]
ious = np_iou(cand_anchors, crowd_boxes) # cand_anchors = anchors[cand_inds]
overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES] # ious = np_iou(cand_anchors, crowd_boxes)
anchor_labels[overlap_with_crowd] = -1 # overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES]
# anchor_labels[overlap_with_crowd] = -1
# Subsample fg labels: ignore some fg if fg is too many # Subsample fg labels: ignore some fg if fg is too many
target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO) target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO)
fg_inds = filter_box_label(anchor_labels, 1, target_num_fg) fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
if len(fg_inds) == 0: # Keep an image even if there is no foreground anchors
raise MalformedData("No valid foreground for RPN!") # if len(fg_inds) == 0:
# Note that fg could be fewer than the target ratio # raise MalformedData("No valid foreground for RPN!")
# Subsample bg labels. num_bg is not allowed to be too many # Subsample bg labels. num_bg is not allowed to be too many
old_num_bg = np.sum(anchor_labels == 0) old_num_bg = np.sum(anchor_labels == 0)
...@@ -162,6 +163,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes): ...@@ -162,6 +163,7 @@ def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
anchor_boxes = np.zeros((NA, 4), dtype='float32') anchor_boxes = np.zeros((NA, 4), dtype='float32')
fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :] fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
anchor_boxes[fg_inds, :] = fg_boxes anchor_boxes[fg_inds, :] = fg_boxes
# assert len(fg_inds) + np.sum(anchor_labels == 0) == config.RPN_BATCH_PER_IM
return anchor_labels, anchor_boxes return anchor_labels, anchor_boxes
...@@ -291,6 +293,7 @@ def get_train_dataflow(): ...@@ -291,6 +293,7 @@ def get_train_dataflow():
def preprocess(img): def preprocess(img):
fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd'] fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd']
boxes = np.copy(boxes)
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
im = im.astype('float32') im = im.astype('float32')
...@@ -366,6 +369,7 @@ def get_eval_dataflow(): ...@@ -366,6 +369,7 @@ def get_eval_dataflow():
if __name__ == '__main__': if __name__ == '__main__':
import os import os
# import IPython as IP; IP.embed()
from tensorpack.dataflow import PrintData from tensorpack.dataflow import PrintData
config.BASEDIR = os.path.expanduser('~/data/coco') config.BASEDIR = os.path.expanduser('~/data/coco')
ds = get_train_dataflow() ds = get_train_dataflow()
......
...@@ -364,6 +364,15 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True): ...@@ -364,6 +364,15 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
return tf.concat([ny0, nx0, ny0 + nh, nx0 + nw], axis=1) return tf.concat([ny0, nx0, ny0 + nh, nx0 + nw], axis=1)
# Expand bbox to a minium size of 1
# boxes_x1y1, boxes_x2y2 = tf.split(boxes, 2, axis=1)
# boxes_wh = boxes_x2y2 - boxes_x1y1
# boxes_center = tf.reshape((boxes_x2y2 + boxes_x1y1) * 0.5, [-1, 2])
# boxes_newwh = tf.maximum(boxes_wh, 1.)
# boxes_x1y1new = boxes_center - boxes_newwh * 0.5
# boxes_x2y2new = boxes_center + boxes_newwh * 0.5
# boxes = tf.concat([boxes_x1y1new, boxes_x2y2new], axis=1)
image_shape = tf.shape(image)[2:] image_shape = tf.shape(image)[2:]
boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size]) boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size])
image = tf.transpose(image, [0, 2, 3, 1]) # 1hwc image = tf.transpose(image, [0, 2, 3, 1]) # 1hwc
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment