Commit d46f15b5 authored by Yuxin Wu's avatar Yuxin Wu

add several augs

parent 4bc639da
Reproduce DQN in: Reproduce DQN in:
**Human-level Control Through Deep Reinforcement Learning** [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
and Double-DQN in: and Double-DQN in:
**Deep Reinforcement Learning with Double Q-learning** [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
Can reproduce the claimed performance, on several games I've tested with. Can reproduce the claimed performance, on several games I've tested with.
...@@ -21,9 +21,11 @@ To train: ...@@ -21,9 +21,11 @@ To train:
``` ```
./DQN.py --rom breakout.bin --gpu 0 ./DQN.py --rom breakout.bin --gpu 0
``` ```
Training speed is about 7.3 iteration/s on 1 Tesla M40. It takes days to learn well (see figure above). Training speed is about 7.3 iteration/s on 1 Tesla M40
(faster than this at the beginning, but will slow down due to exploration annealing).
It takes days to learn well (see figure above).
To play: To visualize the agent:
``` ```
./DQN.py --rom breakout.bin --task play --load pretrained.model ./DQN.py --rom breakout.bin --task play --load pretrained.model
``` ```
......
...@@ -77,7 +77,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir): ...@@ -77,7 +77,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
im = cv2.resize(im, tuple(desSize), interpolation=cv2.INTER_CUBIC) im = cv2.resize(im, tuple(desSize), interpolation=cv2.INTER_CUBIC)
return im return im
transformers = [ transformers = [
imgaug.AugmentWithFunc(resize_func), imgaug.MapImage(resize_func),
imgaug.CenterCrop((224, 224)), imgaug.CenterCrop((224, 224)),
] ]
ds = AugmentImageComponent(ds, transformers) ds = AugmentImageComponent(ds, transformers)
......
...@@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta ...@@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils import get_rng
from six.moves import zip from six.moves import zip
__all__ = ['ImageAugmentor', 'AugmentorList', 'AugmentWithFunc'] __all__ = ['ImageAugmentor', 'AugmentorList']
class ImageAugmentor(object): class ImageAugmentor(object):
""" Base class for an image augmentor""" """ Base class for an image augmentor"""
...@@ -63,14 +63,6 @@ class ImageAugmentor(object): ...@@ -63,14 +63,6 @@ class ImageAugmentor(object):
size = [] size = []
return low + self.rng.rand(*size) * (high - low) return low + self.rng.rand(*size) * (high - low)
class AugmentWithFunc(ImageAugmentor):
    """Augmentor that delegates to a user-supplied callable.

    ``func`` receives an image and must return the resulting image.
    """

    def __init__(self, func):
        # Keep the callable; it is applied unchanged in ``_augment``.
        self.func = func

    def _augment(self, img, _):
        # The augmentation parameter is ignored.
        result = self.func(img)
        return result
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
Augment by a list of augmentors Augment by a list of augmentors
...@@ -83,6 +75,7 @@ class AugmentorList(ImageAugmentor): ...@@ -83,6 +75,7 @@ class AugmentorList(ImageAugmentor):
super(AugmentorList, self).__init__() super(AugmentorList, self).__init__()
def _get_augment_params(self, img): def _get_augment_params(self, img):
# the next augmentor requires the previous one to finish
raise RuntimeError("Cannot simply get parameters of a AugmentorList!") raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
def _augment_return_params(self, img): def _augment_return_params(self, img):
...@@ -107,4 +100,3 @@ class AugmentorList(ImageAugmentor): ...@@ -107,4 +100,3 @@ class AugmentorList(ImageAugmentor):
for a in self.augs: for a in self.augs:
a.reset_state() a.reset_state()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: geometry.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor
import cv2
import numpy as np
__all__ = ['Rotation']
class Rotation(ImageAugmentor):
    """ Randomly rotate the image around a randomly chosen center. """

    def __init__(self, max_deg, center_range=(0, 1)):
        """
        :param max_deg: max abs value of the rotation degree
        :param center_range: (low, high) range from which the rotation
            center is sampled; each sampled value is multiplied by the
            image width / height respectively.
        """
        self._init(locals())

    def _get_augment_params(self, img):
        """ Sample a center and an angle, return the 2x3 rotation matrix. """
        # img.shape[1::-1] is (width, height), the (x, y) order cv2 expects.
        center = img.shape[1::-1] * self._rand_range(
            self.center_range[0], self.center_range[1], (2,))
        deg = self._rand_range(-self.max_deg, self.max_deg)
        # Hand plain Python floats to cv2 rather than numpy scalars.
        center = (float(center[0]), float(center[1]))
        return cv2.getRotationMatrix2D(center, float(deg), 1)

    def _augment(self, img, rot_m):
        """ Warp `img` with the rotation matrix, keeping the original size. """
        return cv2.warpAffine(img, rot_m, img.shape[1::-1],
                              flags=cv2.INTER_CUBIC,
                              borderMode=cv2.BORDER_REPLICATE)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .base import ImageAugmentor from .base import ImageAugmentor
import numpy as np import numpy as np
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize'] __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur', 'Gamma']
class Brightness(ImageAugmentor): class Brightness(ImageAugmentor):
""" """
...@@ -72,3 +72,31 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -72,3 +72,31 @@ class MeanVarianceNormalize(ImageAugmentor):
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape))) std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
img = (img - mean) / std img = (img - mean) / std
return img return img
class GaussianBlur(ImageAugmentor):
    """ Blur the image with a Gaussian kernel of random odd size. """

    def __init__(self, max_size=3):
        """
        :param max_size: exclusive upper bound of the kernel half-size.
            A half-size ``k`` is drawn with ``rng.randint(max_size)`` for
            each axis and the kernel size along that axis is ``2 * k + 1``.
        """
        self._init(locals())

    def _get_augment_params(self, img):
        """ Return the (width, height) of the kernel; both are odd. """
        sx, sy = self.rng.randint(self.max_size, size=(2,))
        # Convert to plain ints: cv2 expects a tuple of Python ints as ksize.
        return int(sx) * 2 + 1, int(sy) * 2 + 1

    def _augment(self, img, s):
        """ Apply the blur; sigma is derived from the kernel size by cv2. """
        # Imported locally: the visible import header of this module does
        # not import cv2.
        import cv2
        return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
                                borderType=cv2.BORDER_REPLICATE)
class Gamma(ImageAugmentor):
    """ Apply a random gamma correction through a 256-entry lookup table. """

    def __init__(self, range=(-0.5, 0.5)):
        """
        :param range: (low, high) range of the sampled ``gamma``. The applied
            exponent is ``1 / (1 + gamma)``, so values must stay above -1.
        """
        self._init(locals())

    def _get_augment_params(self, _):
        """ Sample ``gamma`` from the configured range. """
        return self._rand_range(*self.range)

    def _augment(self, img, gamma):
        """
        Return a gamma-corrected copy of `img`.

        NOTE(review): the table has 256 uint8 entries, so `img` is presumably
        expected to be a uint8 image -- confirm against callers.
        """
        # Imported locally: the visible import header of this module does
        # not import cv2.
        import cv2
        lut = ((np.arange(256, dtype='float32') / 255) **
               (1. / (1. + gamma)) * 255).astype('uint8')
        # Let cv2 allocate the output instead of overwriting the input
        # array, so the caller's image is not modified as a side effect.
        return cv2.LUT(img, lut)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: meta.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor
__all__ = ['RandomChooseAug', 'MapImage']
class RandomChooseAug(ImageAugmentor):
    """ Randomly pick one augmentor from a list and apply it. """

    def __init__(self, aug_lists):
        """
        :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
        """
        if isinstance(aug_lists[0], (tuple, list)):
            prob = [k[1] for k in aug_lists]
            aug_lists = [k[0] for k in aug_lists]
        else:
            # rng.choice needs one probability per candidate, not a scalar.
            prob = [1.0 / len(aug_lists)] * len(aug_lists)
        # `aug_lists` and `prob` are handed over through locals(); they are
        # read back below as self.aug_lists / self.prob.
        self._init(locals())

    def _get_augment_params(self, img):
        """ Return (index of the chosen augmentor, its own parameters). """
        aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob)
        aug_prm = self.aug_lists[aug_idx]._get_augment_params(img)
        return aug_idx, aug_prm

    def _augment(self, img, prm):
        """ Apply the augmentor selected in `_get_augment_params`. """
        idx, prm = prm
        return self.aug_lists[idx]._augment(img, prm)
class MapImage(ImageAugmentor):
    """
    Augmentor that applies an arbitrary function to the image array.
    """

    def __init__(self, func):
        """
        :param func: callable taking an image array and returning the
            augmented image array
        """
        self.func = func

    def _augment(self, img, _):
        # The augmentation parameter is not used; simply delegate to func.
        mapped = self.func(img)
        return mapped
...@@ -6,7 +6,7 @@ from .base import ImageAugmentor ...@@ -6,7 +6,7 @@ from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
__all__ = ['Flip', 'MapImage', 'Resize'] __all__ = ['Flip', 'Resize', 'RandomResize', 'JpegNoise']
class Flip(ImageAugmentor): class Flip(ImageAugmentor):
""" """
...@@ -43,20 +43,6 @@ class Flip(ImageAugmentor): ...@@ -43,20 +43,6 @@ class Flip(ImageAugmentor):
raise NotImplementedError() raise NotImplementedError()
class MapImage(ImageAugmentor):
    """
    Run a user-provided function over the image array.
    """

    def __init__(self, func):
        """
        :param func: function that accepts an image array and returns
            the augmented array
        """
        self.func = func

    def _augment(self, img, _):
        # No sampled parameter is needed; the second argument is ignored.
        out = self.func(img)
        return out
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
""" Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape): def __init__(self, shape):
...@@ -69,3 +55,41 @@ class Resize(ImageAugmentor): ...@@ -69,3 +55,41 @@ class Resize(ImageAugmentor):
return cv2.resize( return cv2.resize(
img, self.shape[::-1], img, self.shape[::-1],
interpolation=cv2.INTER_CUBIC) interpolation=cv2.INTER_CUBIC)
class RandomResize(ImageAugmentor):
    """ randomly rescale w and h of the image"""

    def __init__(self, xrange, yrange, minimum=None, aspect_ratio_thres=0.2):
        """
        :param xrange: (min, max) scaling ratio
        :param yrange: (min, max) scaling ratio
        :param minimum: (xmin, ymin). Avoid scaling down too much.
            ``None`` means no user limit; one pixel per axis is then used
            so that the target size is never empty.
        :param aspect_ratio_thres: maximum relative change of the aspect
            ratio (0.2 means at most 20%)
        """
        if minimum is None:
            # The original default of None could not be indexed below.
            minimum = (1, 1)
        self._init(locals())

    def _get_augment_params(self, img):
        """ Sample a target (width, height) that respects the ratio limit. """
        height, width = img.shape[0], img.shape[1]
        oldr = width * 1.0 / height
        # NOTE(review): this retries until a sample is accepted; it will not
        # terminate if `minimum` forces the ratio outside the threshold.
        while True:
            sx = self._rand_range(*self.xrange)
            sy = self._rand_range(*self.yrange)
            # cv2.resize needs an integer dsize.
            destX = int(max(sx * width, self.minimum[0]))
            destY = int(max(sy * height, self.minimum[1]))
            newr = destX * 1.0 / destY
            diff = abs(newr - oldr) / oldr
            if diff <= self.aspect_ratio_thres:
                return (destX, destY)

    def _augment(self, img, dsize):
        """ Resize `img` to `dsize`, given as (width, height). """
        return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC)
class JpegNoise(ImageAugmentor):
    """ Add JPEG compression artifacts by encoding and decoding the image. """

    def __init__(self, quality_range=(40, 100)):
        """
        :param quality_range: (min, max) range of the JPEG quality setting
        """
        self._init(locals())

    def _get_augment_params(self, img):
        """ Sample an integer JPEG quality from the configured range. """
        # _rand_range yields a float, but the encoder parameter list must
        # contain integers.
        return int(round(self._rand_range(*self.quality_range)))

    def _augment(self, img, q):
        """ Encode `img` as JPEG at quality `q`, then decode it again. """
        # int() also covers parameters supplied by a caller as floats.
        encoded = cv2.imencode('.jpg', img,
                               [cv2.IMWRITE_JPEG_QUALITY, int(q)])[1]
        # Flag 1 is kept from the original call.
        return cv2.imdecode(encoded, 1)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment