Commit b81c2263 authored by Yuxin Wu's avatar Yuxin Wu

randomcroprandomshape

parent 32feff4e
......@@ -38,9 +38,6 @@ class Model(ModelDesc):
keep_prob = tf.constant(0.5 if is_training else 1.0)
if is_training:
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10)
image = image / 4.0 # just to make range smaller
......
......@@ -81,7 +81,6 @@ class Model(ModelDesc):
l = c2 + l
return l
l = conv('conv0', image, 16, 1)
l = BatchNorm('bn0', l, is_training)
l = tf.nn.relu(l)
......
......@@ -5,14 +5,16 @@
import sys
import cv2
from . import AugmentorList, Flip, GaussianDeform, Image
from . import AugmentorList, Image
from .crop import *
anchors = [(0.2, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
#Contrast((0.2,1.8)),
#Flip(horiz=True),
GaussianDeform(anchors, (360,480), 1, randrange=10)
#GaussianDeform(anchors, (360,480), 1, randrange=10)
RandomCropRandomShape(0.3)
])
while True:
......
......@@ -3,12 +3,12 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
from ...utils.rect import Rect
from six.moves import range
import numpy as np
from abc import abstractmethod
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'CenterPaste',
'ConstantBackgroundFiller']
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'RandomCropRandomShape']
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
......@@ -44,79 +44,76 @@ class CenterCrop(ImageAugmentor):
class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location"""
def __init__(self, rangex, rangey):
def __init__(self, rect):
"""
Two arguments defined the range in both axes to crop, min inclued, max excluded.
:param rangex: like (xmin, xmax).
:param rangey: like (ymin, ymax).
:param rect: a `Rect` instance
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
img.arr = img.arr[self.rangey[0]:self.rangey[1],
self.rangex[0]:self.rangex[1]]
img.arr = img.arr[self.rect.y0: self.rect.y1+1,
self.rect.x0: self.rect.x0+1]
if img.coords:
raise NotImplementedError()
class BackgroundFiller(object):
""" Base class for all BackgroundFiller"""
def fill(self, background_shape, img):
def perturb_BB(image_shape, bb, max_pertub_pixel,
rng=None, max_aspect_ratio_diff=0.3,
max_try=100):
"""
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
"""
return self._fill(background_shape, img)
@abstractmethod
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value):
"""
:param value: the value to fill the background.
Perturb a bounding box.
:param image_shape: [h, w]
:param bb: a `Rect` instance
:param max_pertub_pixel: pertubation on each coordinate
:param max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
:param max_try: if cannot find a valid bounding box, return the original
:returns: new bounding box
"""
self.value = value
def _fill(self, background_shape, img):
assert img.ndim in [3, 1]
if img.ndim == 3:
return_shape = background_shape + (3,)
else:
return_shape = background_shape
return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor):
orig_ratio = bb.h * 1.0 / bb.w
if rng is None:
rng = np.random.RandomState()
for _ in range(max_try):
p = rng.randint(-max_pertub_pixel, max_pertub_pixel, [4])
newbb = bb.copy()
newbb.x += p[0]
newbb.y += p[1]
newx1 = bb.x1 + p[2]
newy1 = bb.y1 + p[3]
newbb.w = newx1 - newbb.x
newbb.h = newy1 - newbb.y
if not newbb.validate(image_shape):
continue
new_ratio = newbb.h * 1.0 / newbb.w
diff = abs(new_ratio - orig_ratio)
if diff / orig_ratio > max_aspect_ratio_diff:
continue
return newbb
return bb
class RandomCropRandomShape(ImageAugmentor):
"""
Paste the image onto the center of a background canvas.
Crop a box around a bounding box
"""
def __init__(self, background_shape, background_filler=None):
def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param max_aspect_ratio_diff: keep aspect ratio within the range
"""
if background_filler is None:
background_filler = ConstantBackgroundFiller(0)
self._init(locals())
def _augment(self, img):
img_shape = img.arr.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
background = self.background_filler.fill(
self.background_shape, img.arr)
h0 = (self.background_shape[0] - img_shape[0]) * 0.5
w0 = (self.background_shape[1] - img_shape[1]) * 0.5
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
shape = img.arr.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1])
newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff)
img.arr = newbox.roi(img.arr)
if img.coords:
raise NotImplementedError()
if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: paste.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np
__all__ = [ 'CenterPaste', 'ConstantBackgroundFiller']
class BackgroundFiller(object):
""" Base class for all BackgroundFiller"""
def fill(self, background_shape, img):
"""
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
"""
return self._fill(background_shape, img)
@abstractmethod
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value):
"""
:param value: the value to fill the background.
"""
self.value = value
def _fill(self, background_shape, img):
assert img.ndim in [3, 1]
if img.ndim == 3:
return_shape = background_shape + (3,)
else:
return_shape = background_shape
return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor):
"""
Paste the image onto the center of a background canvas.
"""
def __init__(self, background_shape, background_filler=None):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
"""
if background_filler is None:
background_filler = ConstantBackgroundFiller(0)
self._init(locals())
def _augment(self, img):
img_shape = img.arr.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
background = self.background_filler.fill(
self.background_shape, img.arr)
h0 = (self.background_shape[0] - img_shape[0]) * 0.5
w0 = (self.background_shape[1] - img_shape[1]) * 0.5
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
if img.coords:
raise NotImplementedError()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: rect.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
class Rect(object):
"""
A Rectangle.
Note that x1 = x+w, not x+w-1 or something
"""
__slots__ = ['x', 'y', 'w', 'h']
def __init__(self, x=0, y=0, w=0, h=0):
self.x = x
self.y = y
self.w = w
self.h = h
assert min(self.x, self.y, self.w, self.h) >= 0
@property
def x0(self):
return self.x
@property
def y0(self):
return self.y
@property
def x1(self):
return self.x + self.w
@property
def y1(self):
return self.y + self.h
def copy(self):
new = type(self)()
for i in self.__slots__:
setattr(new, i, getattr(self, i))
return new
def __str__(self):
return 'Rect(x={}, y={}, w={}, h={})'.format(self.x, self.y, self.w, self.h)
def area(self):
return self.w * self.h
def validate(self, shape=None):
"""
Is a valid bounding box within this shape
:param shape: [h, w]
:returns: boolean
"""
if min(self.x, self.y) < 0:
return False
if min(self.w, self.h) <= 0:
return False
if shape is None:
return True
if self.x1 > shape[1] - 1:
return False
if self.y1 > shape[0] - 1:
return False
return True
def roi(self, img):
assert self.validate(img.shape[:2])
return img[self.y0:self.y1+1, self.x0:self.x1+1]
__repr__ = __str__
......@@ -4,6 +4,7 @@
import os, sys
from contextlib import contextmanager
from datetime import datetime
import time
import collections
import numpy as np
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment