Commit 8c9969f8 authored by Yuxin Wu

[STN] Better image sample by gather_nd

parent 100ae5a0
......@@ -10,61 +10,52 @@ from ._test import TestModel
__all__ = ['ImageSample']
# XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206
def sample(img, coords):
    """
    Pick one pixel of ``img`` for every (y, x) coordinate in ``coords``.

    Args:
        img: bxhxwxc. h and w are taken from the static shape.
        coords: bxh2xw2x2. each coordinate is (y, x) integer.
            Out of boundary coordinates will be clipped.
            h2 and w2 must be statically known.
    Return:
        bxh2xw2xc image
    """
    coords = tf.cast(coords, tf.int32)
    shape = img.get_shape().as_list()[1:]   # h, w, c
    # the batch size is read dynamically, so it may be unknown at graph construction
    batch = tf.shape(img)[0]
    shape2 = coords.get_shape().as_list()[1:3]  # h2, w2
    assert None not in shape2, coords.get_shape()
    max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32)

    # clip_by_value actually supports broadcasting:
    # max_coor (2,) is applied along the trailing (y, x) axis
    coords = tf.clip_by_value(coords, 0, max_coor)  # borderMode==repeat

    # prepend the batch index to every (y, x) pair so that a single
    # gather_nd call can index img directly as img[b, y, x]
    batch_index = tf.range(batch, dtype=tf.int32)
    batch_index = tf.reshape(batch_index, [-1, 1, 1, 1])
    batch_index = tf.tile(batch_index, [1, shape2[0], shape2[1], 1])    # bxh2xw2x1
    indices = tf.concat([batch_index, coords], axis=3)  # bxh2xw2x3
    sampled = tf.gather_nd(img, indices)    # bxh2xw2xc
    return sampled
@layer_register(log_shape=True)
def ImageSample(inputs, borderMode='repeat'):
"""
Sample the template image using the given coordinate, by bilinear interpolation.
Sample the images using the given coordinates, by bilinear interpolation.
This was described in the paper:
`Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_.
Args:
inputs (list): [template, coords]. template has shape NHWC.
coords has shape (N,H',W',2), where each pair of the last dimension is a (y, x) real-value
inputs (list): [images, coords]. images has shape NHWC.
coords has shape (N, H', W', 2), where each pair of the last dimension is a (y, x) real-value
coordinate.
borderMode: either "repeat" or "constant" (zero-filled)
Returns:
tf.Tensor: a tensor named ``output`` of shape (N,H',W',C).
tf.Tensor: a tensor named ``output`` of shape (N, H', W', C).
"""
# TODO borderValue
template, mapping = inputs
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
input_shape = template.get_shape().as_list()[1:]
image, mapping = inputs
assert image.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
input_shape = image.get_shape().as_list()[1:]
assert None not in input_shape, \
"Images in ImageSample layer must have fully-defined shape"
assert borderMode in ['repeat', 'constant']
......@@ -86,14 +77,10 @@ def ImageSample(inputs, borderMode='repeat'):
diffy, diffx = tf.split(diff, 2, 3)
neg_diffy, neg_diffx = tf.split(neg_diff, 2, 3)
# prod = tf.reduce_prod(diff, 3, keep_dims=True)
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
# tf.reduce_max(diff), diff], summarize=50)
ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled')
ret = tf.add_n([sample(image, lcoor) * neg_diffx * neg_diffy,
sample(image, ucoor) * diffx * diffy,
sample(image, lyux) * neg_diffy * diffx,
sample(image, uylx) * diffy * neg_diffx], name='sampled')
if borderMode == 'constant':
max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32)
mask = tf.greater_equal(orig_mapping, 0.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment