Commit daf368dc authored by ppwwyyxx's avatar ppwwyyxx

augmentation

parent 87adcc46
......@@ -16,6 +16,7 @@ from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500
......@@ -81,7 +82,15 @@ def get_config():
import cv2
dataset_train = dataset.Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24)))
augmentor = imgaug.AugmentorList([
RandomCrop((24, 24)),
Flip(horiz=True),
BrightnessAdd(0.25),
Contrast((0.2,1.8)),
PerImageWhitening()
])
dataset_train = MapData(dataset_train, lambda img:
augmentor.augment(imgaug.Image(img)).arr)
dataset_train = BatchData(dataset_train, 128)
dataset_test = dataset.Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
......
......@@ -23,7 +23,6 @@ args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G:
global_step_var = get_global_step_var()
config = get_config_func()
config.get_model_func(config.inputs, is_training=False)
init = sessinit.SaverRestore(args.model)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump_train_config.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
import os
from tensorpack.utils.utils import mkdir_p
# Usage: <config script> <output directory> [-n NUMBER]
# Loads a config script, iterates over its dataset and writes the first
# NUMBER images of the selected datapoint component as PNG files.
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='output')
parser.add_argument('-n', '--number', help='number of images to take',
                    default=10, type=int)
args = parser.parse_args()

mkdir_p(args.output)

index = 0   # TODO: as an argument?

# The config script has to expose get_config(); the `dataset` attribute of
# the object it returns is what gets iterated below.
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()

cnt = 0
for dp in config.dataset.get_data():
    if cnt >= args.number:
        break
    imgbatch = dp[index]
    for bi, img in enumerate(imgbatch):
        # Check per image (not per batch) so that at most args.number
        # files are written even when a batch holds more images.
        if cnt >= args.number:
            break
        cnt += 1
        fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
        # NOTE(review): scaling by 255 assumes pixel values in [0, 1] -- confirm
        cv2.imwrite(fname, img * 255.0)
......@@ -7,6 +7,7 @@ from pkgutil import walk_packages
import os
import os.path
import dataset
import imgaug
__SKIP = ['dftools', 'dataset']
def global_import(name):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
from pkgutil import walk_packages
def global_import(name):
    """Import module `name` and copy its names into this module's globals.

    When the imported module defines `__all__`, exactly those names are
    copied; otherwise every name reported by `dir()` on the module is.
    """
    module = __import__(name, globals(), locals())
    if '__all__' in dir(module):
        exported = module.__all__
    else:
        exported = dir(module)
    for key in exported:
        globals()[key] = module.__dict__[key]
# Re-export the contents of every module found next to this file,
# except the ones whose name starts with an underscore.
for _, module_name, _ in walk_packages([os.path.dirname(__file__)]):
    if module_name.startswith('_'):
        continue
    global_import(module_name)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import abstractmethod, ABCMeta
from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList']
class Image(object):
    """Container pairing an image array with optional attributes.

    Augmentors operate on instances of this class. Attributes (such as
    coordinates) have to be augmented accordingly, if necessary.
    """

    def __init__(self, arr, coords=None):
        # arr: the image data; coords: optional attribute, None when absent
        self.arr, self.coords = arr, coords
class ImageAugmentor(object):
    """Base class of the image augmentors.

    Subclasses implement `_augment`; callers use `augment`.
    """
    __metaclass__ = ABCMeta

    def __init__(self):
        # Same setup as subclasses that call _init() themselves.
        self._init()

    def _init(self, params=None):
        """Create the random generator and store constructor arguments.

        Args:
            params: optional dict (subclasses pass `locals()` from their
                `__init__`). Every entry except 'self' is set as an
                attribute of the same name on this instance.
        """
        self.rng = get_rng(self)
        if params:
            # items() is available on both Python 2 and Python 3 dicts,
            # unlike iteritems().
            for k, v in params.items():
                if k != 'self':
                    setattr(self, k, v)

    def augment(self, img):
        """
        Note: will both modify `img` in-place and return `img`
        """
        self._augment(img)
        return img

    @abstractmethod
    def _augment(self, img):
        """
        Augment the image in-place. Will always make it float32 array.
        Args:
            img: the input Image instance
                img.arr must be of shape [h, w] or [h, w, c]
        """

    def _rand_range(self, low=1.0, high=None, size=None):
        """Draw uniform random value(s) from [low, high).

        Args:
            low: lower bound; when `high` is None it is used as the upper
                bound instead and the lower bound becomes 0.
            high: upper bound, or None.
            size: iterable of dimensions for the result; None draws a
                single value.
        """
        if high is None:
            low, high = 0, low
        # Identity test: `== None` compares element-wise on array-like sizes
        # and makes the `if` raise.
        if size is None:
            size = []
        return low + self.rng.rand(*size) * (high - low)
class AugmentorList(ImageAugmentor):
    """
    Run several augmentors one after another on the same image
    """

    def __init__(self, augmentors):
        """
        Args:
            augmentors: iterable of objects whose `augment(img)` is called
                in order
        """
        self.augs = augmentors

    def _augment(self, img):
        # Convert to float32 and divide by 255 once, before any augmentor runs
        scaled = img.arr.astype('float32') / 255.0
        img.arr = scaled
        for augmentor in self.augs:
            augmentor.augment(img)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
__all__ = ['RandomCrop']
class RandomCrop(ImageAugmentor):
    """Randomly crop the image into a smaller one."""

    def __init__(self, crop_shape):
        """
        Args:
            crop_shape: shape in (h, w)
        """
        self._init(locals())

    def _augment(self, img):
        """Replace `img.arr` by a randomly positioned crop of `crop_shape`.

        Raises:
            NotImplementedError: if the image carries coordinates.
            ValueError: if the crop is larger than the image.
        """
        # Coordinates cannot be cropped yet; refuse before touching the
        # array so the Image is not left half-augmented.
        if img.coords is not None and len(img.coords) > 0:
            raise NotImplementedError()
        orig_shape = img.arr.shape
        crop_h, crop_w = self.crop_shape[0], self.crop_shape[1]
        if crop_h > orig_shape[0] or crop_w > orig_shape[1]:
            raise ValueError(
                "crop_shape {} is larger than image shape {}".format(
                    (crop_h, crop_w), tuple(orig_shape[:2])))
        # randint excludes its upper bound, hence the +1: every valid offset
        # can be drawn, and a crop equal to the image size yields offset 0
        # instead of an error.
        h0 = self.rng.randint(0, orig_shape[0] - crop_h + 1)
        w0 = self.rng.randint(0, orig_shape[1] - crop_w + 1)
        img.arr = img.arr[h0:h0 + crop_h, w0:w0 + crop_w]
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: imgproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
import numpy as np
__all__ = ['BrightnessAdd', 'Contrast', 'PerImageWhitening']
class BrightnessAdd(ImageAugmentor):
    """
    Randomly add a value within [-delta,delta], and clip in [0,1]
    """

    def __init__(self, delta):
        """
        Args:
            delta: positive bound of the random brightness offset
        """
        assert delta > 0
        self._init(locals())

    def _augment(self, img):
        v = self._rand_range(-self.delta, self.delta)
        # Build a new array instead of using `+=`: the in-place form writes
        # into the buffer img.arr refers to (which the caller may still
        # hold) and fails when that buffer has an integer dtype.
        img.arr = np.clip(img.arr + v, 0, 1)
class Contrast(ImageAugmentor):
    """
    Scale the deviation from the mean by a random factor:
    x = (x - mean) * contrast_factor + mean, with the mean taken over
    the first two axes.
    """

    def __init__(self, factor_range):
        """
        Args:
            factor_range: arguments forwarded to `_rand_range` to draw the
                contrast factor, e.g. (low, high)
        """
        self._init(locals())

    def _augment(self, img):
        factor = self._rand_range(*self.factor_range)
        pixels = img.arr
        center = np.mean(pixels, axis=(0, 1), keepdims=True)
        deviation = pixels - center
        img.arr = deviation * factor + center
class PerImageWhitening(ImageAugmentor):
    """
    Linearly rescale the image by its own statistics:
        x = (x - mean) / adjusted_stddev
    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels)).
    Mean and stddev are taken over the first two axes.
    """

    def __init__(self):
        # No parameters and no random state needed
        pass

    def _augment(self, img):
        pixels = img.arr
        # Lower bound on the divisor, based on the number of pixels (h * w)
        floor = 1.0 / np.sqrt(np.prod(pixels.shape[:2]))
        center = np.mean(pixels, axis=(0, 1), keepdims=True)
        spread = np.std(pixels, axis=(0, 1), keepdims=True)
        spread = np.maximum(spread, floor)
        img.arr = (pixels - center) / spread
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: noname.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
import numpy as np
import cv2
__all__ = ['Flip']
class Flip(ImageAugmentor):
    """Randomly flip the image horizontally and/or vertically."""

    def __init__(self, horiz=False, vert=False, prob=0.5):
        """
        Random flip.
        Args:
            horiz, vert: True/False; at least one of them must be True
            prob: the image is flipped when a uniform draw from [0, 1)
                is below this value
        Raises:
            RuntimeError: if neither horiz nor vert is set
        """
        # Flip code handed to cv2.flip in _augment
        if horiz and vert:
            self.code = -1
        elif horiz:
            self.code = 1
        elif vert:
            self.code = 0
        else:
            raise RuntimeError("Are you kidding?")
        self.prob = prob
        self._init()

    def _augment(self, img):
        """Flip `img.arr` with probability `prob`.

        Raises:
            NotImplementedError: if the image carries coordinates.
        """
        # Coordinates cannot be flipped yet. Refuse up front so the failure
        # is deterministic and the array is never changed without them.
        if img.coords is not None and len(img.coords) > 0:
            raise NotImplementedError()
        if self._rand_range() < self.prob:
            img.arr = cv2.flip(img.arr, self.code)
......@@ -9,6 +9,7 @@ import time
import sys
from contextlib import contextmanager
import tensorflow as tf
import numpy as np
import collections
import logger
......@@ -103,3 +104,6 @@ def get_global_step_var():
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
return global_step_var
def get_rng(self):
    """Return a new `np.random.RandomState` created without an explicit seed.

    Args:
        self: the object requesting the generator; not used here.
    """
    rng = np.random.RandomState()
    return rng
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment