Commit b81c2263 authored by Yuxin Wu's avatar Yuxin Wu

randomcroprandomshape

parent 32feff4e
...@@ -38,9 +38,6 @@ class Model(ModelDesc): ...@@ -38,9 +38,6 @@ class Model(ModelDesc):
keep_prob = tf.constant(0.5 if is_training else 1.0) keep_prob = tf.constant(0.5 if is_training else 1.0)
if is_training: if is_training:
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10) tf.image_summary("train_image", image, 10)
image = image / 4.0 # just to make range smaller image = image / 4.0 # just to make range smaller
......
...@@ -81,7 +81,6 @@ class Model(ModelDesc): ...@@ -81,7 +81,6 @@ class Model(ModelDesc):
l = c2 + l l = c2 + l
return l return l
l = conv('conv0', image, 16, 1) l = conv('conv0', image, 16, 1)
l = BatchNorm('bn0', l, is_training) l = BatchNorm('bn0', l, is_training)
l = tf.nn.relu(l) l = tf.nn.relu(l)
......
...@@ -5,14 +5,16 @@ ...@@ -5,14 +5,16 @@
import sys import sys
import cv2 import cv2
from . import AugmentorList, Flip, GaussianDeform, Image from . import AugmentorList, Image
from .crop import *
anchors = [(0.2, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] anchors = [(0.2, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([ augmentors = AugmentorList([
#Contrast((0.2,1.8)), #Contrast((0.2,1.8)),
#Flip(horiz=True), #Flip(horiz=True),
GaussianDeform(anchors, (360,480), 1, randrange=10) #GaussianDeform(anchors, (360,480), 1, randrange=10)
RandomCropRandomShape(0.3)
]) ])
while True: while True:
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils.rect import Rect
from six.moves import range
import numpy as np import numpy as np
from abc import abstractmethod
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'CenterPaste', __all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'RandomCropRandomShape']
'ConstantBackgroundFiller']
class RandomCrop(ImageAugmentor): class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
...@@ -44,79 +44,76 @@ class CenterCrop(ImageAugmentor): ...@@ -44,79 +44,76 @@ class CenterCrop(ImageAugmentor):
class FixedCrop(ImageAugmentor): class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location""" """ Crop a rectangle at a given location"""
def __init__(self, rangex, rangey): def __init__(self, rect):
""" """
Two arguments defined the range in both axes to crop, min inclued, max excluded. Two arguments defined the range in both axes to crop, min inclued, max excluded.
:param rangex: like (xmin, xmax). :param rect: a `Rect` instance
:param rangey: like (ymin, ymax).
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img):
orig_shape = img.arr.shape orig_shape = img.arr.shape
img.arr = img.arr[self.rangey[0]:self.rangey[1], img.arr = img.arr[self.rect.y0: self.rect.y1+1,
self.rangex[0]:self.rangex[1]] self.rect.x0: self.rect.x0+1]
if img.coords: if img.coords:
raise NotImplementedError() raise NotImplementedError()
class BackgroundFiller(object): def perturb_BB(image_shape, bb, max_pertub_pixel,
""" Base class for all BackgroundFiller""" rng=None, max_aspect_ratio_diff=0.3,
def fill(self, background_shape, img): max_try=100):
""" """
Return a proper background image of background_shape, given img Perturb a bounding box.
:param image_shape: [h, w]
:param background_shape: a shape of [h, w] :param bb: a `Rect` instance
:param img: an image :param max_pertub_pixel: pertubation on each coordinate
:returns: a background image :param max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
""" :param max_try: if cannot find a valid bounding box, return the original
return self._fill(background_shape, img) :returns: new bounding box
"""
@abstractmethod orig_ratio = bb.h * 1.0 / bb.w
def _fill(self, background_shape, img): if rng is None:
pass rng = np.random.RandomState()
for _ in range(max_try):
class ConstantBackgroundFiller(BackgroundFiller): p = rng.randint(-max_pertub_pixel, max_pertub_pixel, [4])
""" Fill the background by a constant """ newbb = bb.copy()
def __init__(self, value): newbb.x += p[0]
""" newbb.y += p[1]
:param value: the value to fill the background. newx1 = bb.x1 + p[2]
""" newy1 = bb.y1 + p[3]
self.value = value newbb.w = newx1 - newbb.x
newbb.h = newy1 - newbb.y
def _fill(self, background_shape, img): if not newbb.validate(image_shape):
assert img.ndim in [3, 1] continue
if img.ndim == 3: new_ratio = newbb.h * 1.0 / newbb.w
return_shape = background_shape + (3,) diff = abs(new_ratio - orig_ratio)
else: if diff / orig_ratio > max_aspect_ratio_diff:
return_shape = background_shape continue
return np.zeros(return_shape) + self.value return newbb
return bb
class CenterPaste(ImageAugmentor):
class RandomCropRandomShape(ImageAugmentor):
""" """
Paste the image onto the center of a background canvas. Crop a box around a bounding box
""" """
def __init__(self, background_shape, background_filler=None): def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
""" """
:param background_shape: shape of the background canvas. :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler. :param max_aspect_ratio_diff: keep aspect ratio within the range
""" """
if background_filler is None:
background_filler = ConstantBackgroundFiller(0)
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img):
img_shape = img.arr.shape[:2] shape = img.arr.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1] box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1])
background = self.background_filler.fill( newbox = perturb_BB(shape, box, dist,
self.background_shape, img.arr) self.rng, self.max_aspect_ratio_diff)
h0 = (self.background_shape[0] - img_shape[0]) * 0.5
w0 = (self.background_shape[1] - img_shape[1]) * 0.5 img.arr = newbox.roi(img.arr)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
if img.coords: if img.coords:
raise NotImplementedError() raise NotImplementedError()
if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: paste.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np
__all__ = [ 'CenterPaste', 'ConstantBackgroundFiller']
class BackgroundFiller(object):
""" Base class for all BackgroundFiller"""
def fill(self, background_shape, img):
"""
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
"""
return self._fill(background_shape, img)
@abstractmethod
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value):
"""
:param value: the value to fill the background.
"""
self.value = value
def _fill(self, background_shape, img):
assert img.ndim in [3, 1]
if img.ndim == 3:
return_shape = background_shape + (3,)
else:
return_shape = background_shape
return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor):
"""
Paste the image onto the center of a background canvas.
"""
def __init__(self, background_shape, background_filler=None):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
"""
if background_filler is None:
background_filler = ConstantBackgroundFiller(0)
self._init(locals())
def _augment(self, img):
img_shape = img.arr.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
background = self.background_filler.fill(
self.background_shape, img.arr)
h0 = (self.background_shape[0] - img_shape[0]) * 0.5
w0 = (self.background_shape[1] - img_shape[1]) * 0.5
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
if img.coords:
raise NotImplementedError()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: rect.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
class Rect(object):
"""
A Rectangle.
Note that x1 = x+w, not x+w-1 or something
"""
__slots__ = ['x', 'y', 'w', 'h']
def __init__(self, x=0, y=0, w=0, h=0):
self.x = x
self.y = y
self.w = w
self.h = h
assert min(self.x, self.y, self.w, self.h) >= 0
@property
def x0(self):
return self.x
@property
def y0(self):
return self.y
@property
def x1(self):
return self.x + self.w
@property
def y1(self):
return self.y + self.h
def copy(self):
new = type(self)()
for i in self.__slots__:
setattr(new, i, getattr(self, i))
return new
def __str__(self):
return 'Rect(x={}, y={}, w={}, h={})'.format(self.x, self.y, self.w, self.h)
def area(self):
return self.w * self.h
def validate(self, shape=None):
"""
Is a valid bounding box within this shape
:param shape: [h, w]
:returns: boolean
"""
if min(self.x, self.y) < 0:
return False
if min(self.w, self.h) <= 0:
return False
if shape is None:
return True
if self.x1 > shape[1] - 1:
return False
if self.y1 > shape[0] - 1:
return False
return True
def roi(self, img):
assert self.validate(img.shape[:2])
return img[self.y0:self.y1+1, self.x0:self.x1+1]
__repr__ = __str__
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os, sys import os, sys
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime
import time import time
import collections import collections
import numpy as np import numpy as np
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment