Commit 2a444073 authored by Yuxin Wu

more augmentor

parent 0266827f
...@@ -57,7 +57,7 @@ class Model(ModelDesc): ...@@ -57,7 +57,7 @@ class Model(ModelDesc):
outs.append(x4) outs.append(x4)
return tf.concat(3, outs, name='concat') return tf.concat(3, outs, name='concat')
with argscope(Conv2D, nl=BNReLU(), use_bias=False): with argscope(Conv2D, nl=BNReLU, use_bias=False):
l = Conv2D('conv0', image, 64, 7, stride=2) l = Conv2D('conv0', image, 64, 7, stride=2)
l = MaxPooling('pool0', l, 3, 2, padding='SAME') l = MaxPooling('pool0', l, 3, 2, padding='SAME')
l = Conv2D('conv1', l, 64, 1) l = Conv2D('conv1', l, 64, 1)
......
...@@ -36,7 +36,7 @@ class Model(ModelDesc): ...@@ -36,7 +36,7 @@ class Model(ModelDesc):
def _build_graph(self, input_vars): def _build_graph(self, input_vars):
image, label = input_vars image, label = input_vars
image = image / 128.0 - 1 # ? image = image / 255.0 # ?
def proj_kk(l, k, ch_r, ch, stride=1): def proj_kk(l, k, ch_r, ch, stride=1):
l = Conv2D('conv{0}{0}r'.format(k), l, ch_r, 1) l = Conv2D('conv{0}{0}r'.format(k), l, ch_r, 1)
...@@ -70,8 +70,8 @@ class Model(ModelDesc): ...@@ -70,8 +70,8 @@ class Model(ModelDesc):
.Conv2D('conv277ba', ch_r, [7,1]) .Conv2D('conv277ba', ch_r, [7,1])
.Conv2D('conv277bb', ch, [1,7])()) .Conv2D('conv277bb', ch, [1,7])())
nl = BNReLU(decay=0.9997, epsilon=1e-3) with argscope(Conv2D, nl=BNReLU, use_bias=False),\
with argscope(Conv2D, nl=nl, use_bias=False): argscope(BatchNorm, decay=0.9997, epsilon=1e-3):
l = (LinearWrap(image) l = (LinearWrap(image)
.Conv2D('conv0', 32, 3, stride=2, padding='VALID') #299 .Conv2D('conv0', 32, 3, stride=2, padding='VALID') #299
.Conv2D('conv1', 32, 3, padding='VALID') #149 .Conv2D('conv1', 32, 3, padding='VALID') #149
...@@ -269,8 +269,8 @@ def get_config(): ...@@ -269,8 +269,8 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), ModelSaver(), StatPrinter(), ModelSaver(),
InferenceRunner(dataset_val, [ InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-top1-error'), ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-top5-error')]), ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(5, 0.03), (9, 0.01), (12, 0.006), [(5, 0.03), (9, 0.01), (12, 0.006),
(17, 0.003), (22, 1e-3), (36, 2e-4), (17, 0.003), (22, 1e-3), (36, 2e-4),
......
...@@ -41,7 +41,7 @@ class Model(ModelDesc): ...@@ -41,7 +41,7 @@ class Model(ModelDesc):
tf.image_summary("train_image", image, 10) tf.image_summary("train_image", image, 10)
image = image / 4.0 # just to make range smaller image = image / 4.0 # just to make range smaller
with argscope(Conv2D, nl=BNReLU(), use_bias=False, kernel_shape=3): with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3):
logits = LinearWrap(image) \ logits = LinearWrap(image) \
.Conv2D('conv1.1', out_channel=64) \ .Conv2D('conv1.1', out_channel=64) \
.Conv2D('conv1.2', out_channel=64) \ .Conv2D('conv1.2', out_channel=64) \
......
...@@ -156,7 +156,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -156,7 +156,7 @@ class ILSVRC12(RNGDataFlow):
def get_data(self): def get_data(self):
""" """
Produce original images of shape [h, w, 3], and label, Produce original images of shape [h, w, 3(BGR)], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax] and optionally a bbox of [xmin, ymin, xmax, ymax]
""" """
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
......
...@@ -6,10 +6,10 @@ from abc import abstractmethod, ABCMeta ...@@ -6,10 +6,10 @@ from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils import get_rng
from six.moves import zip from six.moves import zip
__all__ = ['ImageAugmentor', 'AugmentorList'] __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
class ImageAugmentor(object): class Augmentor(object):
""" Base class for an image augmentor""" """ Base class for an augmentor"""
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self): def __init__(self):
...@@ -24,37 +24,32 @@ class ImageAugmentor(object): ...@@ -24,37 +24,32 @@ class ImageAugmentor(object):
def reset_state(self): def reset_state(self):
self.rng = get_rng(self) self.rng = get_rng(self)
def augment(self, img): def augment(self, d):
""" """
Perform augmentation on the image in-place. Perform augmentation on the data.
:param img: an [h,w] or [h,w,c] image
:returns: the augmented image, always of type 'float32'
""" """
img, params = self._augment_return_params(img) d, params = self._augment_return_params(d)
return img return d
def _augment_return_params(self, img): def _augment_return_params(self, d):
""" """
Augment the image and return both image and params Augment the image and return both image and params
""" """
prms = self._get_augment_params(img) prms = self._get_augment_params(d)
return (self._augment(img, prms), prms) return (self._augment(d, prms), prms)
@abstractmethod @abstractmethod
def _augment(self, img, param): def _augment(self, d, param):
""" """
augment with the given param and return the new image augment with the given param and return the new image
""" """
def _get_augment_params(self, img): def _get_augment_params(self, d):
""" """
get the augmentor parameters get the augmentor parameters
""" """
return None return None
def _fprop_coord(self, coord, param):
return coord
def _rand_range(self, low=1.0, high=None, size=None): def _rand_range(self, low=1.0, high=None, size=None):
if high is None: if high is None:
low, high = 0, low low, high = 0, low
...@@ -62,6 +57,19 @@ class ImageAugmentor(object): ...@@ -62,6 +57,19 @@ class ImageAugmentor(object):
size = [] size = []
return self.rng.uniform(low, high, size) return self.rng.uniform(low, high, size)
class ImageAugmentor(Augmentor):
    """Augmentor whose data items are images."""

    def augment(self, img):
        """
        Augment one image and return only the result, discarding the
        parameters that were sampled for it.

        :param img: an [h,w] or [h,w,c] image
        :returns: the augmented image. Documented upstream as always being
            of type 'float32'; nothing in this method enforces that, so it
            depends on the subclass's `_augment` -- TODO confirm.
        """
        augmented, _ = self._augment_return_params(img)
        return augmented

    def _fprop_coord(self, coord, param):
        """
        Default coordinate hook: hands `coord` back untouched and ignores
        `param`.
        """
        return coord
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
Augment by a list of augmentors Augment by a list of augmentors
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import cv2 import cv2
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur', __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
'Gamma', 'Clip'] 'Gamma', 'Clip', 'Saturation']
class Brightness(ImageAugmentor): class Brightness(ImageAugmentor):
""" """
...@@ -111,9 +111,22 @@ class Gamma(ImageAugmentor): ...@@ -111,9 +111,22 @@ class Gamma(ImageAugmentor):
class Clip(ImageAugmentor): class Clip(ImageAugmentor):
def __init__(self, min=0, max=255): def __init__(self, min=0, max=255):
assert delta > 0
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
img = np.clip(img, self.min, self.max) img = np.clip(img, self.min, self.max)
return img return img
class Saturation(ImageAugmentor):
    """
    Random saturation change: blend the image with its greyscale version.

    See 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
    """

    def __init__(self, alpha=0.4):
        """
        :param alpha: bound on the random offset added to the blend factor;
            must be smaller than 1.
        """
        super(Saturation, self).__init__()
        assert alpha < 1
        # NOTE: no extra local names may be introduced before this call,
        # since the whole locals() dict is handed to _init.
        self._init(locals())

    def _get_augment_params(self, _):
        """Sample the blend factor: 1 plus an offset from _rand_range(-alpha, alpha)."""
        offset = self._rand_range(-self.alpha, self.alpha)
        return 1 + offset

    def _augment(self, img, v):
        """
        Mix `img` (weight `v`) with its greyscale version (weight `1 - v`).

        Assumes `img` is a 3-channel BGR array that cv2.cvtColor accepts
        -- TODO confirm against callers.
        """
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # add a trailing channel axis so the 2-D grey plane broadcasts
        # against the [h, w, c] colour image
        grey_term = (grey * (1 - v))[:, :, np.newaxis]
        return img * v + grey_term
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
from .base import ImageAugmentor from .base import ImageAugmentor
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug'] __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
'RandomOrderAug']
class Identity(ImageAugmentor): class Identity(ImageAugmentor):
def _augment(self, img, _): def _augment(self, img, _):
...@@ -15,8 +16,8 @@ class Identity(ImageAugmentor): ...@@ -15,8 +16,8 @@ class Identity(ImageAugmentor):
class RandomApplyAug(ImageAugmentor): class RandomApplyAug(ImageAugmentor):
""" Randomly apply the augmentor with a prob. Otherwise do nothing""" """ Randomly apply the augmentor with a prob. Otherwise do nothing"""
def __init__(self, aug, prob): def __init__(self, aug, prob):
super(RandomApplyAug, self).__init__()
self._init(locals()) self._init(locals())
super(RandomApplyAug, self).__init__()
def _get_augment_params(self, img): def _get_augment_params(self, img):
p = self.rng.rand() p = self.rng.rand()
...@@ -41,7 +42,6 @@ class RandomChooseAug(ImageAugmentor): ...@@ -41,7 +42,6 @@ class RandomChooseAug(ImageAugmentor):
""" """
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
""" """
super(RandomChooseAug, self).__init__()
if isinstance(aug_lists[0], (tuple, list)): if isinstance(aug_lists[0], (tuple, list)):
prob = [k[1] for k in aug_lists] prob = [k[1] for k in aug_lists]
aug_lists = [k[0] for k in aug_lists] aug_lists = [k[0] for k in aug_lists]
...@@ -49,6 +49,7 @@ class RandomChooseAug(ImageAugmentor): ...@@ -49,6 +49,7 @@ class RandomChooseAug(ImageAugmentor):
else: else:
prob = 1.0 / len(aug_lists) prob = 1.0 / len(aug_lists)
self._init(locals()) self._init(locals())
super(RandomChooseAug, self).__init__()
def reset_state(self): def reset_state(self):
super(RandomChooseAug, self).reset_state() super(RandomChooseAug, self).reset_state()
...@@ -64,6 +65,34 @@ class RandomChooseAug(ImageAugmentor): ...@@ -64,6 +65,34 @@ class RandomChooseAug(ImageAugmentor):
idx, prm = prm idx, prm = prm
return self.aug_lists[idx]._augment(img, prm) return self.aug_lists[idx]._augment(img, prm)
class RandomOrderAug(ImageAugmentor):
    """Apply every augmentor of a list, in an order reshuffled for each call."""

    def __init__(self, aug_lists):
        """
        Shuffle the augmentors into random order.

        :param aug_lists: list of augmentor. Every element is used directly
            as an augmentor, so (augmentor, probability) tuples are NOT
            accepted here.
        """
        # Attributes are stored before the base constructor runs, as in
        # RandomApplyAug / RandomChooseAug -- presumably because the base
        # constructor triggers reset_state(), which reads self.aug_lists
        # (confirm against Augmentor.__init__).
        self._init(locals())
        super(RandomOrderAug, self).__init__()

    def reset_state(self):
        """Reset this augmentor's own state, then that of every wrapped augmentor."""
        super(RandomOrderAug, self).reset_state()
        for a in self.aug_lists:
            a.reset_state()

    def _get_augment_params(self, img):
        """
        Draw a random application order and one parameter set per augmentor.

        :returns: tuple (idxs, prms); `idxs` is a permutation of the
            augmentor indices, `prms[k]` belongs to `self.aug_lists[k]`.
        """
        # Note: If augmentors change the shape of image, get_augment_param might not work
        # All augmentors should only rely on the shape of image
        idxs = self.rng.permutation(len(self.aug_lists))
        # Parameters are sampled in list order (not in shuffled order), all
        # against the same un-augmented input image.
        prms = [aug._get_augment_params(img) for aug in self.aug_lists]
        return idxs, prms

    def _augment(self, img, prm):
        """Run the augmentors over `img` in the order given by `prm`'s index array."""
        idxs, prms = prm
        for k in idxs:
            img = self.aug_lists[k]._augment(img, prms[k])
        return img
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
""" """
Map the image array by a function. Map the image array by a function.
......
...@@ -63,15 +63,7 @@ def LeakyReLU(x, alpha, name=None): ...@@ -63,15 +63,7 @@ def LeakyReLU(x, alpha, name=None):
name = 'output' name = 'output'
return tf.mul(x, 0.5, name=name) return tf.mul(x, 0.5, name=name)
# I'm not a layer, but I return a nonlinearity. def BNReLU(x, name=None):
def BNReLU(is_training=None, **kwargs): x = BatchNorm('bn', x, use_local_stat=None)
""" x = tf.nn.relu(x, name=name)
:param is_traning: boolean return x
:param kwargs: args for BatchNorm
:returns: an activation function that performs BN + ReLU (a too common combination)
"""
def BNReLU(x, name=None):
x = BatchNorm('bn', x, use_local_stat=is_training, **kwargs)
x = tf.nn.relu(x, name=name)
return x
return BNReLU
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment