Commit 22c0f6ac authored by Yuxin Wu's avatar Yuxin Wu

random paste and image sampling

parent 0f6289a4
...@@ -115,7 +115,6 @@ class BatchDataByShape(BatchData): ...@@ -115,7 +115,6 @@ class BatchDataByShape(BatchData):
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
shp = dp[self.idx].shape shp = dp[self.idx].shape
print(shp, len(self.holder))
holder = self.holder[shp] holder = self.holder[shp]
holder.append(dp) holder.append(dp)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
......
...@@ -15,8 +15,6 @@ from ..base import RNGDataFlow ...@@ -15,8 +15,6 @@ from ..base import RNGDataFlow
__all__ = ['Mnist'] __all__ = ['Mnist']
""" This file is mostly copied from tensorflow example """
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory): def maybe_download(filename, work_directory):
......
...@@ -7,7 +7,7 @@ from .base import ImageAugmentor ...@@ -7,7 +7,7 @@ from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
__all__ = ['JpegNoise', 'GaussianNoise'] __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor): class JpegNoise(ImageAugmentor):
def __init__(self, quality_range=(40, 100)): def __init__(self, quality_range=(40, 100)):
......
...@@ -8,7 +8,8 @@ from .base import ImageAugmentor ...@@ -8,7 +8,8 @@ from .base import ImageAugmentor
from abc import abstractmethod from abc import abstractmethod
import numpy as np import numpy as np
__all__ = [ 'CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller'] __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste']
class BackgroundFiller(object): class BackgroundFiller(object):
...@@ -36,7 +37,7 @@ class ConstantBackgroundFiller(BackgroundFiller): ...@@ -36,7 +37,7 @@ class ConstantBackgroundFiller(BackgroundFiller):
self.value = value self.value = value
def _fill(self, background_shape, img): def _fill(self, background_shape, img):
assert img.ndim in [3, 1] assert img.ndim in [3, 2]
if img.ndim == 3: if img.ndim == 3:
return_shape = background_shape + (3,) return_shape = background_shape + (3,)
else: else:
...@@ -63,12 +64,30 @@ class CenterPaste(ImageAugmentor): ...@@ -63,12 +64,30 @@ class CenterPaste(ImageAugmentor):
background = self.background_filler.fill( background = self.background_filler.fill(
self.background_shape, img) self.background_shape, img)
h0 = int((self.background_shape[0] - img_shape[0]) * 0.5) y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
w0 = int((self.background_shape[1] - img_shape[1]) * 0.5) x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img
img = background return background
return img
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class RandomPaste(CenterPaste):
"""
Randomly paste the image onto a background convas
"""
def _get_augment_params(self, img):
img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
y0 = self._rand_range(self.background_shape[0] - img_shape[0])
x0 = self._rand_range(self.background_shape[1] - img_shape[1])
return int(x0), int(y0)
def _augment(self, img, loc):
x0, y0 = loc
img_shape = img.shape[:2]
background = self.background_filler.fill(
self.background_shape, img)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img
return background
...@@ -18,11 +18,11 @@ def sample(img, coords): ...@@ -18,11 +18,11 @@ def sample(img, coords):
:param coords: bxh2xw2x2 (y, x) integer :param coords: bxh2xw2x2 (y, x) integer
:return: bxh2xw2xc image :return: bxh2xw2xc image
""" """
coords = tf.cast(coords, tf.int32)
shape = img.get_shape().as_list()[1:] shape = img.get_shape().as_list()[1:]
shape2 = coords.get_shape().as_list()[1:3] shape2 = coords.get_shape().as_list()[1:3]
max_coor = tf.constant([shape[0] - 1, shape[1] - 1]) max_coor = tf.constant([shape[0] - 1, shape[1] - 1], dtype=tf.int32)
coords = tf.minimum(coords, max_coor) coords = tf.clip_by_value(coords, 0, max_coor)
coords = tf.maximum(coords, tf.constant(0))
w = shape[1] w = shape[1]
coords = tf.reshape(coords, [-1, 2]) coords = tf.reshape(coords, [-1, 2])
...@@ -46,8 +46,8 @@ def ImageSample(inputs): ...@@ -46,8 +46,8 @@ def ImageSample(inputs):
It mimics the same behavior described in: It mimics the same behavior described in:
`Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_. `Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_.
:param input: [template, mapping]. template of shape NHWC. mapping of :param input: [template, mapping]. template of shape NHWC.
shape NHW2, where each pair of the last dimension is a (y, x) real-value mapping of shape NHW2, where each pair of the last dimension is a (y, x) real-value
coordinate. coordinate.
:returns: a NHWC output tensor. :returns: a NHWC output tensor.
""" """
...@@ -55,13 +55,10 @@ def ImageSample(inputs): ...@@ -55,13 +55,10 @@ def ImageSample(inputs):
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4 assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
mapping = tf.maximum(mapping, 0.0) mapping = tf.maximum(mapping, 0.0)
lcoor = tf.cast(mapping, tf.int32) # floor lcoor = tf.floor(mapping)
ucoor = lcoor + 1 ucoor = lcoor + 1
# has to cast to int32 and then cast back diff = mapping - lcoor
# tf.floor have gradient 1 w.r.t input
# TODO bug fixed in #951
diff = mapping - tf.cast(lcoor, tf.float32)
neg_diff = 1.0 - diff #bxh2xw2x2 neg_diff = 1.0 - diff #bxh2xw2x2
lcoory, lcoorx = tf.split(3, 2, lcoor) lcoory, lcoorx = tf.split(3, 2, lcoor)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment