Commit c22edc77 authored by Yuxin Wu

centerpatch dataflow

parent d382ea9d
......@@ -4,7 +4,10 @@
from .base import ImageAugmentor
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop']
import numpy as np
from abc import abstractmethod
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'CenterPaste']
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
......@@ -47,3 +50,47 @@ class FixedCrop(ImageAugmentor):
self.rangex[0]:self.rangex[1]]
if img.coords:
raise NotImplementedError()
class BackgroundFiller(object):
    """
    Base class for objects that produce a background canvas to paste an
    image onto. Subclasses override :meth:`fill`.
    """

    # NOTE: this class does not use ABCMeta as its metaclass, so the
    # decorator alone does not prevent instantiation; the explicit raise
    # below is what signals a missing override.
    @abstractmethod
    def fill(self, background_shape, img):
        """
        Return a proper background image of background_shape, given img.

        :param background_shape: shape of the background to create.
        :param img: the image that is going to be pasted onto the background.
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()
class ConstantBackgroundFiller(BackgroundFiller):
    """ Fill the whole background with a single constant value """

    def __init__(self, value):
        """
        :param value: the value every background element is set to.
        """
        self.value = value

    def fill(self, background_shape, img):
        """
        Build a background filled with ``self.value``.

        :param background_shape: shape of the background, without a
            channel axis.
        :param img: array to be pasted later; only its ``ndim`` is read.
        :returns: array of ``background_shape`` for a 2-D ``img``, or
            ``background_shape + (3,)`` for a 3-D ``img``.
        """
        # 3-D is an image with a channel axis, 2-D is a single-channel
        # image without one. A 1-D array is not an image.
        assert img.ndim in [3, 2]
        # Coerce to tuple so that a list shape can be concatenated as well.
        return_shape = tuple(background_shape)
        if img.ndim == 3:
            # NOTE: the channel count is fixed at 3 here.
            return_shape = return_shape + (3,)
        return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor):
    """
    Paste the image onto center of a background
    """

    def __init__(self, background_shape, background_filler=None):
        """
        :param background_shape: (height, width) of the background; each
            must be at least as large as the image's.
        :param background_filler: object whose
            ``fill(background_shape, img)`` returns the background array.
            Defaults to ``ConstantBackgroundFiller(0)``.
        """
        if background_filler is None:
            background_filler = ConstantBackgroundFiller(0)
        # NOTE(review): _init comes from ImageAugmentor (not visible here);
        # presumably it stores the arguments as attributes, since _augment
        # reads self.background_shape and self.background_filler.
        self._init(locals())

    def _augment(self, img):
        """
        Replace ``img.arr`` with a background that has the original
        ``img.arr`` written into its center.

        :raises NotImplementedError: if ``img.coords`` is truthy.
        """
        img_shape = img.arr.shape[:2]
        # An image exactly as large as the background is a valid
        # (zero-margin) paste, hence >= rather than >.
        assert self.background_shape[0] >= img_shape[0] \
            and self.background_shape[1] >= img_shape[1]

        background = self.background_filler.fill(
            self.background_shape, img.arr)
        # Offsets are used as slice indices, so they must be integers.
        h0 = int((self.background_shape[0] - img_shape[0]) // 2)
        w0 = int((self.background_shape[1] - img_shape[1]) // 2)
        background[h0:h0 + img_shape[0], w0:w0 + img_shape[1]] = img.arr
        img.arr = background
        if img.coords:
            raise NotImplementedError()
......@@ -11,13 +11,14 @@ class BrightnessAdd(ImageAugmentor):
"""
Randomly add a value within [-delta,delta], and clip to [0,255] when `clip` is set
"""
def __init__(self, delta):
def __init__(self, delta, clip=True):
assert delta > 0
self._init(locals())
def _augment(self, img):
    """
    Shift ``img.arr`` in place by a value obtained from
    ``self._rand_range(-self.delta, self.delta)``, then clip the result
    to [0, 255] when ``self.clip`` is set.
    """
    bound = self.delta
    shift = self._rand_range(-bound, bound)
    img.arr += shift
    if not self.clip:
        return
    img.arr = np.clip(img.arr, 0, 255)
class Contrast(ImageAugmentor):
......@@ -25,7 +26,7 @@ class Contrast(ImageAugmentor):
Apply x = (x - mean) * contrast_factor + mean to each channel
and, when `clip` is set, clip to [0, 255]
"""
def __init__(self, factor_range):
def __init__(self, factor_range, clip=True):
self._init(locals())
def _augment(self, img):
......@@ -33,6 +34,7 @@ class Contrast(ImageAugmentor):
r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True)
img.arr = (arr - mean) * r + mean
if self.clip:
img.arr = np.clip(img.arr, 0, 255)
class MeanVarianceNormalize(ImageAugmentor):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: nonl.py
# File: nonlin.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......@@ -8,7 +8,7 @@ from copy import copy
from ._common import *
__all__ = ['Maxout', 'PReLU']
__all__ = ['Maxout', 'PReLU', 'LeakyReLU']
@layer_register()
def Maxout(x, num_unit):
......@@ -28,3 +28,12 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
return x * 0.5
else:
return tf.mul(x, 0.5, name=name)
@layer_register()
def LeakyReLU(x, alpha, name=None):
    """
    Compute ``0.5 * ((1 + alpha) * x + (1 - alpha) * |x|)``, which equals
    ``x`` where x is non-negative and ``alpha * x`` where x is negative.

    :param x: input tensor.
    :param alpha: slope applied to negative inputs; converted with float().
    :param name: optional name given to the final multiplication op.
    """
    slope = float(alpha)
    positive_part = (1 + slope) * x
    magnitude_part = (1 - slope) * tf.abs(x)
    doubled = (positive_part + magnitude_part)
    if name is not None:
        return tf.mul(doubled, 0.5, name=name)
    return doubled * 0.5
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment