Commit 7510f165 authored by Yuxin Wu's avatar Yuxin Wu

[STN] integer clip_by_value is so slow..

parent 7d3a24eb
...@@ -20,14 +20,14 @@ def sample(img, coords): ...@@ -20,14 +20,14 @@ def sample(img, coords):
Return: Return:
bxh2xw2xc image bxh2xw2xc image
""" """
coords = tf.to_int32(coords)
shape = img.get_shape().as_list()[1:] # h, w, c shape = img.get_shape().as_list()[1:] # h, w, c
batch = tf.shape(img)[0] batch = tf.shape(img)[0]
shape2 = coords.get_shape().as_list()[1:3] # h2, w2 shape2 = coords.get_shape().as_list()[1:3] # h2, w2
assert None not in shape2, coords.get_shape() assert None not in shape2, coords.get_shape()
max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32) max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.float32)
coords = tf.clip_by_value(coords, 0, max_coor) # borderMode==repeat coords = tf.clip_by_value(coords, 0., max_coor) # borderMode==repeat
coords = tf.to_int32(coords)
batch_index = tf.range(batch, dtype=tf.int32) batch_index = tf.range(batch, dtype=tf.int32)
batch_index = tf.reshape(batch_index, [-1, 1, 1, 1]) batch_index = tf.reshape(batch_index, [-1, 1, 1, 1])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment