Commit 7510f165 authored by Yuxin Wu's avatar Yuxin Wu

[STN] integer clip_by_value is so slow..

parent 7d3a24eb
......@@ -20,14 +20,14 @@ def sample(img, coords):
Return:
bxh2xw2xc image
"""
coords = tf.to_int32(coords)
shape = img.get_shape().as_list()[1:] # h, w, c
batch = tf.shape(img)[0]
shape2 = coords.get_shape().as_list()[1:3] # h2, w2
assert None not in shape2, coords.get_shape()
max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32)
max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.float32)
coords = tf.clip_by_value(coords, 0, max_coor) # borderMode==repeat
coords = tf.clip_by_value(coords, 0., max_coor) # borderMode==repeat
coords = tf.to_int32(coords)
batch_index = tf.range(batch, dtype=tf.int32)
batch_index = tf.reshape(batch_index, [-1, 1, 1, 1])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment