Commit 22c0f6ac authored by Yuxin Wu's avatar Yuxin Wu

random paste and image sampling

parent 0f6289a4
......@@ -115,7 +115,6 @@ class BatchDataByShape(BatchData):
def get_data(self):
for dp in self.ds.get_data():
shp = dp[self.idx].shape
print(shp, len(self.holder))
holder = self.holder[shp]
holder.append(dp)
if len(holder) == self.batch_size:
......
......@@ -15,8 +15,6 @@ from ..base import RNGDataFlow
__all__ = ['Mnist']
""" This file is mostly copied from tensorflow example """
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
......
......@@ -7,7 +7,7 @@ from .base import ImageAugmentor
import numpy as np
import cv2
__all__ = ['JpegNoise', 'GaussianNoise']
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor):
def __init__(self, quality_range=(40, 100)):
......
......@@ -8,7 +8,8 @@ from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np
__all__ = [ 'CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller']
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste']
class BackgroundFiller(object):
......@@ -36,7 +37,7 @@ class ConstantBackgroundFiller(BackgroundFiller):
self.value = value
def _fill(self, background_shape, img):
assert img.ndim in [3, 1]
assert img.ndim in [3, 2]
if img.ndim == 3:
return_shape = background_shape + (3,)
else:
......@@ -63,12 +64,30 @@ class CenterPaste(ImageAugmentor):
background = self.background_filler.fill(
self.background_shape, img)
h0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img
img = background
return img
y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img
return background
# Coordinate handling is not implemented for this augmentor: any call raises
# NotImplementedError. `coord` and `param` are accepted but never read.
# NOTE(review): presumably `param` is the value returned by
# `_get_augment_params` -- confirm against the ImageAugmentor base class.
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class RandomPaste(CenterPaste):
    """
    Randomly paste the image onto a background canvas.

    The canvas is obtained from ``self.background_filler.fill`` using
    ``self.background_shape``; both attributes are read here but set
    elsewhere (presumably by the ``CenterPaste`` constructor -- confirm).
    """

    def _get_augment_params(self, img):
        """
        Draw the top-left corner at which ``img`` will be pasted.

        :param img: image array; only its first two dimensions
            (height, width) are used.
        :returns: ``(x0, y0)`` integer offsets into the background.
        :raises AssertionError: if the image is larger than the background
            along either of the first two dimensions.
        """
        img_shape = img.shape[:2]
        # An image exactly as large as the canvas along an axis can still be
        # pasted (the only valid offset is 0), so equality is allowed.
        # NOTE(review): relies on `_rand_range(0)` yielding 0 -- confirm
        # against the base-class helper.
        assert (self.background_shape[0] >= img_shape[0]
                and self.background_shape[1] >= img_shape[1])

        y0 = self._rand_range(self.background_shape[0] - img_shape[0])
        x0 = self._rand_range(self.background_shape[1] - img_shape[1])
        # Offsets are used as slice indices in `_augment`, hence the int().
        return int(x0), int(y0)

    def _augment(self, img, loc):
        """
        Paste ``img`` onto a freshly filled background at ``loc``.

        :param img: image array to paste.
        :param loc: ``(x0, y0)`` offsets as returned by
            ``_get_augment_params``.
        :returns: the background with ``img`` written into it.
        """
        x0, y0 = loc
        img_shape = img.shape[:2]
        background = self.background_filler.fill(
            self.background_shape, img)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background
......@@ -18,11 +18,11 @@ def sample(img, coords):
:param coords: bxh2xw2x2 (y, x) integer
:return: bxh2xw2xc image
"""
coords = tf.cast(coords, tf.int32)
shape = img.get_shape().as_list()[1:]
shape2 = coords.get_shape().as_list()[1:3]
max_coor = tf.constant([shape[0] - 1, shape[1] - 1])
coords = tf.minimum(coords, max_coor)
coords = tf.maximum(coords, tf.constant(0))
max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32)
coords = tf.clip_by_value(coords, 0, max_coor)
w = shape[1]
coords = tf.reshape(coords, [-1, 2])
......@@ -46,8 +46,8 @@ def ImageSample(inputs):
It mimics the same behavior described in:
`Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_.
:param input: [template, mapping]. template of shape NHWC. mapping of
shape NHW2, where each pair of the last dimension is a (y, x) real-value
:param input: [template, mapping]. template of shape NHWC.
mapping of shape NHW2, where each pair of the last dimension is a (y, x) real-value
coordinate.
:returns: a NHWC output tensor.
"""
......@@ -55,13 +55,10 @@ def ImageSample(inputs):
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
mapping = tf.maximum(mapping, 0.0)
lcoor = tf.cast(mapping, tf.int32) # floor
lcoor = tf.floor(mapping)
ucoor = lcoor + 1
# has to cast to int32 and then cast back
# tf.floor have gradient 1 w.r.t input
# TODO bug fixed in #951
diff = mapping - tf.cast(lcoor, tf.float32)
diff = mapping - lcoor
neg_diff = 1.0 - diff #bxh2xw2x2
lcoory, lcoorx = tf.split(3, 2, lcoor)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment