Commit 10186aa1 authored by Yuxin Wu

[FasterRCNN] crop_and_resize takes index as well

parent e1e2ee8d
......@@ -285,16 +285,15 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
@under_name_scope()
def crop_and_resize(image, boxes, size):
def crop_and_resize(image, boxes, box_ind, crop_size):
"""
Better-aligned version of tf.image.crop_and_resize,
following our definition of floating point boxes.
Args:
image: 1CHW
image: NCHW
boxes: nx4, x1y1x2y2
size (int):
box_ind: (n,)
crop_size (int):
Returns:
n,C,crop_size,crop_size
"""
......@@ -330,12 +329,11 @@ def crop_and_resize(image, boxes, size):
return tf.concat([ny0, nx0, ny0 + nh, nx0 + nw], axis=1)
image_shape = tf.shape(image)[2:]
boxes = transform_fpcoor_for_tf(boxes, image_shape, [size, size])
boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size])
image = tf.transpose(image, [0, 2, 3, 1]) # 1hwc
ret = tf.image.crop_and_resize(
image, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
crop_size=[size, size])
image, boxes, box_ind,
crop_size=[crop_size, crop_size])
ret = tf.transpose(ret, [0, 3, 1, 2]) # ncss
return ret
......@@ -356,7 +354,10 @@ def roi_align(featuremap, boxes, output_shape):
boxes = tf.stop_gradient(boxes) # TODO
# sample 4 locations per roi bin
ret = crop_and_resize(featuremap, boxes, output_shape * 2)
ret = crop_and_resize(
featuremap, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
output_shape * 2)
ret = tf.nn.avg_pool(ret, [1, 1, 2, 2], [1, 1, 2, 2], padding='SAME', data_format='NCHW')
return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment