Commit 2ef9669d authored by Yuxin Wu

borderMode in STN

parent 3f19c1b8
...@@ -24,6 +24,9 @@ class JpegNoise(ImageAugmentor): ...@@ -24,6 +24,9 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
def __init__(self, scale=10, clip=True): def __init__(self, scale=10, clip=True):
"""
Add a gaussian noise of the same shape to img.
"""
super(GaussianNoise, self).__init__() super(GaussianNoise, self).__init__()
self._init(locals()) self._init(locals())
...@@ -38,6 +41,9 @@ class GaussianNoise(ImageAugmentor): ...@@ -38,6 +41,9 @@ class GaussianNoise(ImageAugmentor):
class SaltPepperNoise(ImageAugmentor): class SaltPepperNoise(ImageAugmentor):
def __init__(self, white_prob=0.05, black_prob=0.05): def __init__(self, white_prob=0.05, black_prob=0.05):
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
"""
assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1" assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1"
super(SaltPepperNoise, self).__init__() super(SaltPepperNoise, self).__init__()
self._init(locals()) self._init(locals())
......
...@@ -10,19 +10,21 @@ from ._common import layer_register ...@@ -10,19 +10,21 @@ from ._common import layer_register
__all__ = ['ImageSample'] __all__ = ['ImageSample']
# XXX TODO ugly. # XXX TODO ugly.
# really need to fix this after tensorflow supports multiple indexing # really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206 # See github:tensorflow#418,#206
def sample(img, coords): def sample(img, coords, borderMode):
""" """
:param img: bxhxwxc :param img: bxhxwxc
:param coords: bxh2xw2x2 (y, x) integer :param coords: bxh2xw2x2 (y, x) floating point (but is actually holding integer)
:return: bxh2xw2xc image :return: bxh2xw2xc image
""" """
coords = tf.cast(coords, tf.int32) orig_coords = tf.cast(coords, tf.int32)
shape = img.get_shape().as_list()[1:] shape = img.get_shape().as_list()[1:]
shape2 = coords.get_shape().as_list()[1:3] shape2 = coords.get_shape().as_list()[1:3]
max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32) max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32)
coords = tf.clip_by_value(coords, 0, max_coor)
# clip_by_value actually supports broadcasting
coords = tf.clip_by_value(orig_coords, 0, max_coor) # borderMode==repeat
w = shape[1] w = shape[1]
coords = tf.reshape(coords, [-1, 2]) coords = tf.reshape(coords, [-1, 2])
...@@ -37,10 +39,18 @@ def sample(img, coords): ...@@ -37,10 +39,18 @@ def sample(img, coords):
img = tf.reshape(img, [-1, shape[2]]) #bhw x c img = tf.reshape(img, [-1, shape[2]]) #bhw x c
sampled = tf.gather(img, flat_coords) sampled = tf.gather(img, flat_coords)
if borderMode == 'constant':
mask = tf.less_equal(orig_coords, max_coor)
mask2 = tf.greater_equal(orig_coords, 0)
mask = tf.logical_and(mask, mask2) #bxh2xw2x2
mask = tf.reduce_all(mask, [3]) # bxh2xw2 boolean
mask = tf.expand_dims(mask, 3)
sampled = sampled * tf.cast(mask, tf.float32)
return sampled return sampled
@layer_register() @layer_register()
def ImageSample(inputs): def ImageSample(inputs, borderMode='repeat'):
""" """
Sample the template image, using the given coordinate, by bilinear interpolation. Sample the template image, using the given coordinate, by bilinear interpolation.
It mimics the same behavior described in: It mimics the same behavior described in:
...@@ -49,10 +59,12 @@ def ImageSample(inputs): ...@@ -49,10 +59,12 @@ def ImageSample(inputs):
:param input: [template, mapping]. template of shape NHWC. :param input: [template, mapping]. template of shape NHWC.
mapping of shape NHW2, where each pair of the last dimension is a (y, x) real-value mapping of shape NHW2, where each pair of the last dimension is a (y, x) real-value
coordinate. coordinate.
:param borderMode: either 'repeat' or 'constant' (0)
:returns: a NHWC output tensor. :returns: a NHWC output tensor.
""" """
template, mapping = inputs template, mapping = inputs
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4 assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
assert borderMode in ['repeat', 'constant']
mapping = tf.maximum(mapping, 0.0) mapping = tf.maximum(mapping, 0.0)
lcoor = tf.floor(mapping) lcoor = tf.floor(mapping)
...@@ -72,13 +84,12 @@ def ImageSample(inputs): ...@@ -72,13 +84,12 @@ def ImageSample(inputs):
#prod = tf.reduce_prod(diff, 3, keep_dims=True) #prod = tf.reduce_prod(diff, 3, keep_dims=True)
#diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod), #diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
#tf.reduce_max(diff), diff], #tf.reduce_max(diff), diff], summarize=50)
#summarize=50)
return tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy, return tf.add_n([sample(template, lcoor, borderMode) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy, sample(template, ucoor, borderMode) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx, sample(template, lyux, borderMode) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled') sample(template, uylx, borderMode) * diffy * neg_diffx], name='sampled')
from ._test import TestModel from ._test import TestModel
class TestSample(TestModel): class TestSample(TestModel):
......
...@@ -46,7 +46,8 @@ def get_global_step_var(): ...@@ -46,7 +46,8 @@ def get_global_step_var():
assert scope.name == '', \ assert scope.name == '', \
"Creating global_step_var under a variable scope would cause problems!" "Creating global_step_var under a variable scope would cause problems!"
var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[], var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
initializer=tf.constant_initializer(), trainable=False) initializer=tf.constant_initializer(),
trainable=False, dtype=tf.int32)
return var return var
def get_global_step(): def get_global_step():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment