Commit daf368dc authored by ppwwyyxx's avatar ppwwyyxx

augmentation

parent 87adcc46
...@@ -16,6 +16,7 @@ from tensorpack.utils.summary import * ...@@ -16,6 +16,7 @@ from tensorpack.utils.summary import *
from tensorpack.utils.callback import * from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import * from tensorpack.utils.validation_callback import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
BATCH_SIZE = 128 BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500 MIN_AFTER_DEQUEUE = 500
...@@ -81,7 +82,15 @@ def get_config(): ...@@ -81,7 +82,15 @@ def get_config():
import cv2 import cv2
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24))) augmentor = imgaug.AugmentorList([
RandomCrop((24, 24)),
Flip(horiz=True),
BrightnessAdd(0.25),
Contrast((0.2,1.8)),
PerImageWhitening()
])
dataset_train = MapData(dataset_train, lambda img:
augmentor.augment(imgaug.Image(img)).arr)
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
dataset_test = dataset.Cifar10('test') dataset_test = dataset.Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24))) dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
......
...@@ -23,7 +23,6 @@ args = parser.parse_args() ...@@ -23,7 +23,6 @@ args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G: with tf.Graph().as_default() as G:
global_step_var = get_global_step_var()
config = get_config_func() config = get_config_func()
config.get_model_func(config.inputs, is_training=False) config.get_model_func(config.inputs, is_training=False)
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump_train_config.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
import os
from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='output')
parser.add_argument('-n', '--number', help='number of images to take',
default=10, type=int)
args = parser.parse_args()
mkdir_p(args.output)
index = 0 # TODO: as an argument?
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
cnt = 0
for dp in config.dataset.get_data():
imgbatch = dp[index]
if cnt > args.number:
break
for bi, img in enumerate(imgbatch):
cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img * 255.0)
...@@ -7,6 +7,7 @@ from pkgutil import walk_packages ...@@ -7,6 +7,7 @@ from pkgutil import walk_packages
import os import os
import os.path import os.path
import dataset import dataset
import imgaug
__SKIP = ['dftools', 'dataset'] __SKIP = ['dftools', 'dataset']
def global_import(name): def global_import(name):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
from pkgutil import walk_packages
def global_import(name):
p = __import__(name, globals(), locals())
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import abstractmethod, ABCMeta
from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList']
class Image(object):
""" An image with attributes, for augmentor to operate on
Attributes (such as coordinates) have to be augmented acoordingly, if necessary
"""
def __init__(self, arr, coords=None):
self.arr = arr
self.coords = coords
class ImageAugmentor(object):
__metaclass__ = ABCMeta
def __init__(self):
self.rng = get_rng(self)
def _init(self, params=None):
self.rng = get_rng(self)
if params:
for k, v in params.iteritems():
if k != 'self':
setattr(self, k, v)
def augment(self, img):
"""
Note: will both modify `img` in-place and return `img`
"""
self._augment(img)
return img
@abstractmethod
def _augment(self, img):
"""
Augment the image in-place. Will always make it float32 array.
Args:
img: the input Image instance
img.arr must be of shape [h, w] or [h, w, c]
"""
def _rand_range(self, low=1.0, high=None, size=None):
if high is None:
low, high = 0, low
if size == None:
size = []
return low + self.rng.rand(*size) * (high - low)
class AugmentorList(ImageAugmentor):
"""
Augment by a list of augmentors
"""
def __init__(self, augmentors):
self.augs = augmentors
def _augment(self, img):
img.arr = img.arr.astype('float32') / 255.0
for aug in self.augs:
aug.augment(img)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
__all__ = ['RandomCrop']
class RandomCrop(ImageAugmentor):
def __init__(self, crop_shape):
"""
Randomly crop the image into a smaller one
Args:
crop_shape: shape in (h, w)
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
h0 = self.rng.randint(0, orig_shape[0] - self.crop_shape[0])
w0 = self.rng.randint(0, orig_shape[1] - self.crop_shape[1])
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
if img.coords:
raise NotImplementedError()
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: imgproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
import numpy as np
__all__ = ['BrightnessAdd', 'Contrast', 'PerImageWhitening']
class BrightnessAdd(ImageAugmentor):
"""
Randomly add a value within [-delta,delta], and clip in [0,1]
"""
def __init__(self, delta):
assert delta > 0
self._init(locals())
def _augment(self, img):
v = self._rand_range(-self.delta, self.delta)
img.arr += v
img.arr = np.clip(img.arr, 0, 1)
class Contrast(ImageAugmentor):
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
"""
def __init__(self, factor_range):
self._init(locals())
def _augment(self, img):
arr = img.arr
r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True)
img.arr = (arr - mean) * r + mean
class PerImageWhitening(ImageAugmentor):
"""
Linearly scales image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels))
"""
def __init__(self):
pass
def _augment(self, img):
mean = np.mean(img.arr, axis=(0,1), keepdims=True)
std = np.std(img.arr, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape[:2])))
img.arr = (img.arr - mean) / std
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: noname.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
import numpy as np
import cv2
__all__ = ['Flip']
class Flip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5):
"""
Random flip.
Args:
horiz, vert: True/False
"""
if horiz and vert:
self.code = -1
elif horiz:
self.code = 1
elif vert:
self.code = 0
else:
raise RuntimeError("Are you kidding?")
self.prob = prob
self._init()
def _augment(self, img):
if self._rand_range() < self.prob:
img.arr = cv2.flip(img.arr, self.code)
if img.coords:
raise NotImplementedError()
...@@ -9,6 +9,7 @@ import time ...@@ -9,6 +9,7 @@ import time
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
import numpy as np
import collections import collections
import logger import logger
...@@ -103,3 +104,6 @@ def get_global_step_var(): ...@@ -103,3 +104,6 @@ def get_global_step_var():
global_step_var = tf.Variable( global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME) 0, trainable=False, name=GLOBAL_STEP_OP_NAME)
return global_step_var return global_step_var
def get_rng(self):
return np.random.RandomState()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment