Commit 49b61fc0 authored by Yuxin Wu's avatar Yuxin Wu

make augmentors type-safe

parent 9b62b218
...@@ -38,7 +38,6 @@ MOCK_MODULES = ['scipy', ...@@ -38,7 +38,6 @@ MOCK_MODULES = ['scipy',
'sklearn', 'functools32'] 'sklearn', 'functools32']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name) sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['tensorflow'].__version__ = '0.12.0'
import tensorpack import tensorpack
...@@ -55,7 +54,7 @@ extensions = [ ...@@ -55,7 +54,7 @@ extensions = [
'sphinx.ext.todo', 'sphinx.ext.todo',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
#'sphinx.ext.autosectionlabel', #'sphinx.ext.autosectionlabel',
# 'sphinx.ext.coverage', #'sphinx.ext.coverage',
'sphinx.ext.mathjax', 'sphinx.ext.mathjax',
'sphinx.ext.intersphinx', 'sphinx.ext.intersphinx',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
...@@ -104,9 +103,9 @@ author = u'Yuxin Wu' ...@@ -104,9 +103,9 @@ author = u'Yuxin Wu'
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = u'0.2' version = tensorpack.__version__
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = u'0.2' release = version
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
......
...@@ -10,7 +10,7 @@ not necessarily the best for different scenarios. ...@@ -10,7 +10,7 @@ not necessarily the best for different scenarios.
### Use TensorFlow queues ### Use TensorFlow queues
In general, ``feed_dict`` is slow and should never appear in your critical loop. In general, `feed_dict` is slow and should never appear in your critical loop.
i.e., you should avoid loops like this: i.e., you should avoid loops like this:
```python ```python
while True: while True:
...@@ -31,10 +31,44 @@ while True: ...@@ -31,10 +31,44 @@ while True:
minimize_op.run() # minimize_op was built from dequeued tensors minimize_op.run() # minimize_op was built from dequeued tensors
``` ```
This is automatically handled by tensorpack trainers already (unless you used the demo ``SimpleTrainer``), This is now automatically handled by tensorpack trainers already (unless you used the demo ``SimpleTrainer``),
see [Trainer](trainer.md) for details. see [Trainer](trainer.md) for details.
TensorFlow is providing staging interface which may further improve the speed. This is TensorFlow is providing staging interface which may further improve the speed. This is
[issue#140](https://github.com/ppwwyyxx/tensorpack/issues/140). [issue#140](https://github.com/ppwwyyxx/tensorpack/issues/140).
You can also avoid `feed_dict` by using TensorFlow native operators to read data, which is also
supported here.
It probably allows you to reach the best performance, but at the cost of implementing the
reading / preprocessing ops in C++ if there isn't one for your task. We won't talk about it here.
### Figure out your bottleneck ### Figure out your bottleneck
For training we will only worry about the throughput but not the latency.
Thread 1 & 2 runs in parallel, and the faster one will block to wait for the slower one.
So the overall throughput will appear to be the slower one.
There isn't a way to accurately benchmark the two threads while they are running, without introducing overhead. But
there are ways to understand which one is the bottleneck:
1. Use the average occupancy (size) of the queue. This information is summarized after every epoch (TODO depend on #125).
If the queue is nearly empty, then the data thread is the bottleneck.
2. Benchmark them separately. You can use `TestDataSpeed` to benchmark a DataFlow, and
use `FakeData` as a fast replacement in a dry run to benchmark the training
iterations.
### Load ImageNet efficiently
We take ImageNet dataset as an example of how to optimize a DataFlow for speed.
We use ILSVRC12 training set, which contains 1.28 million images.
Following the [ResNet example](../examples/ResNet), our pre-processing need images in their original resolution, so we don't resize them.
The average resolution is about 400x350 <sup>[[1]]</sup>.
The original images (JPEG compressed) are 140G in total.
[1]: #ref
<div id=ref> </div>
[[1]]. [ImageNet: A Large-Scale Hierarchical Image Database](http://www.image-net.org/papers/imagenet_cvpr09.pdf), CVPR09
...@@ -5,9 +5,7 @@ import shutil ...@@ -5,9 +5,7 @@ import shutil
# setup metainfo # setup metainfo
CURRENT_DIR = os.path.dirname(__file__) CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, 'tensorpack/libinfo.py') libinfo_py = os.path.join(CURRENT_DIR, 'tensorpack/libinfo.py')
libinfo = {'__file__': libinfo_py} exec(open(libinfo_py, "rb").read())
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
__version__ = libinfo['__version__']
# produce rst readme for pypi # produce rst readme for pypi
try: try:
......
...@@ -18,7 +18,7 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'Fixed ...@@ -18,7 +18,7 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'Fixed
class TestDataSpeed(ProxyDataFlow): class TestDataSpeed(ProxyDataFlow):
""" Test the speed of some DataFlow """ """ Test the speed of some DataFlow """
def __init__(self, ds, size=1000): def __init__(self, ds, size=5000):
""" """
Args: Args:
ds (DataFlow): the DataFlow to test. ds (DataFlow): the DataFlow to test.
...@@ -117,7 +117,7 @@ class BatchData(ProxyDataFlow): ...@@ -117,7 +117,7 @@ class BatchData(ProxyDataFlow):
tp = dt.dtype tp = dt.dtype
try: try:
result.append( result.append(
np.array([x[k] for x in data_holder], dtype=tp)) np.asarray([x[k] for x in data_holder], dtype=tp))
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except: except:
......
...@@ -99,7 +99,6 @@ class AugmentorList(ImageAugmentor): ...@@ -99,7 +99,6 @@ class AugmentorList(ImageAugmentor):
def _augment_return_params(self, img): def _augment_return_params(self, img):
assert img.ndim in [2, 3], img.ndim assert img.ndim in [2, 3], img.ndim
img = img.astype('float32')
prms = [] prms = []
for a in self.augs: for a in self.augs:
...@@ -109,7 +108,6 @@ class AugmentorList(ImageAugmentor): ...@@ -109,7 +108,6 @@ class AugmentorList(ImageAugmentor):
def _augment(self, img, param): def _augment(self, img, param):
assert img.ndim in [2, 3], img.ndim assert img.ndim in [2, 3], img.ndim
img = img.astype('float32')
for aug, prm in zip(self.augs, param): for aug, prm in zip(self.augs, param):
img = aug._augment(img, prm) img = aug._augment(img, prm)
return img return img
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: convert.py
from .base import ImageAugmentor
from .meta import MapImage
import numpy as np
import cv2
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
class ColorSpace(ImageAugmentor):
""" Convert into another colorspace. """
def __init__(self, mode=cv2.COLOR_BGR2GRAY, keepdims=True):
"""
Args:
mode: opencv colorspace conversion code (e.g., `cv2.COLOR_BGR2HSV`)
keepdims (bool): keep the dimension of image unchanged if opencv
changes it.
"""
self._init(locals())
def _augment(self, img, _):
transf = cv2.cvtColor(img, self.mode)
if self.keepdims:
if len(transf.shape) is not len(img.shape):
transf = transf[..., None]
return transf
class Grayscale(ColorSpace):
""" Convert image to grayscale. """
def __init__(self, keepdims=True, rgb=False):
"""
Args:
keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
rgb (bool): interpret input as RGB instead of the default BGR
"""
mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY
super(Grayscale, self).__init__(mode, keepdims)
class ToUint8(MapImage):
""" Convert image to uint8. Useful to reduce communication overhead. """
def __init__(self):
super(ToUint8, self).__init__(lambda x: np.clip(x, 0, 255).astype(np.uint8))
class ToFloat32(MapImage):
""" Convert image to float32, may increase quality of the augmentor. """
def __init__(self):
super(ToFloat32, self).__init__(lambda x: x.astype(np.float32))
...@@ -6,46 +6,10 @@ from .base import ImageAugmentor ...@@ -6,46 +6,10 @@ from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
__all__ = ['ColorSpace', 'Hue', 'Grayscale', 'Brightness', 'Contrast', 'MeanVarianceNormalize', __all__ = ['Hue', 'Brightness', 'Contrast', 'MeanVarianceNormalize',
'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting'] 'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting']
class ColorSpace(ImageAugmentor):
"""
Convert into another colorspace.
"""
def __init__(self, mode=cv2.COLOR_BGR2GRAY, keepdims=True):
"""
Args:
mode: opencv colorspace conversion code (e.g., `cv2.COLOR_BGR2HSV`)
keepdims (bool): keep the dimension of image unchanged if opencv
changes it.
"""
self._init(locals())
def _augment(self, img, _):
transf = cv2.cvtColor(img, self.mode)
if self.keepdims:
if len(transf.shape) is not len(img.shape):
transf = transf[..., None]
return transf
class Grayscale(ColorSpace):
"""
Convert image to grayscale.
"""
def __init__(self, keepdims=True, rgb=False):
"""
Args:
keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
rgb (bool): interpret input as RGB instead of the default BGR
"""
mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY
super(Grayscale, self).__init__(mode, keepdims)
class Hue(ImageAugmentor): class Hue(ImageAugmentor):
""" Randomly change color hue of a BGR input. """ Randomly change color hue of a BGR input.
""" """
...@@ -85,10 +49,12 @@ class Brightness(ImageAugmentor): ...@@ -85,10 +49,12 @@ class Brightness(ImageAugmentor):
return v return v
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype
img = img.astype('float32')
img += v img += v
if self.clip: if self.clip or old_dtype == np.uint8:
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
return img return img.asypte(old_dtype)
class Contrast(ImageAugmentor): class Contrast(ImageAugmentor):
...@@ -109,18 +75,23 @@ class Contrast(ImageAugmentor): ...@@ -109,18 +75,23 @@ class Contrast(ImageAugmentor):
return self._rand_range(*self.factor_range) return self._rand_range(*self.factor_range)
def _augment(self, img, r): def _augment(self, img, r):
old_dtype = img.dtype
img = img.astype('float32')
mean = np.mean(img, axis=(0, 1), keepdims=True) mean = np.mean(img, axis=(0, 1), keepdims=True)
img = (img - mean) * r + mean img = (img - mean) * r + mean
if self.clip: if self.clip or old_dtype == np.uint8:
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
return img return img.astype(old_dtype)
class MeanVarianceNormalize(ImageAugmentor): class MeanVarianceNormalize(ImageAugmentor):
""" """
Linearly scales the image to have zero mean and unit norm. Linearly scales the image to have zero mean and unit norm.
``x = (x - mean) / adjusted_stddev`` ``x = (x - mean) / adjusted_stddev``
where ``adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))`` where ``adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
This augmentor always returns float32 images.
``
""" """
def __init__(self, all_channel=True): def __init__(self, all_channel=True):
...@@ -131,6 +102,7 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -131,6 +102,7 @@ class MeanVarianceNormalize(ImageAugmentor):
self.all_channel = all_channel self.all_channel = all_channel
def _augment(self, img, _): def _augment(self, img, _):
img = img.astype('float32')
if self.all_channel: if self.all_channel:
mean = np.mean(img) mean = np.mean(img)
std = np.std(img) std = np.std(img)
...@@ -178,9 +150,10 @@ class Gamma(ImageAugmentor): ...@@ -178,9 +150,10 @@ class Gamma(ImageAugmentor):
return self._rand_range(*self.range) return self._rand_range(*self.range)
def _augment(self, img, gamma): def _augment(self, img, gamma):
old_dtype = img.dtype
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8') lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = np.clip(img, 0, 255).astype('uint8') img = np.clip(img, 0, 255).astype('uint8')
img = cv2.LUT(img, lut).astype('float32') img = cv2.LUT(img, lut).astype(old_dtype)
return img return img
...@@ -218,8 +191,10 @@ class Saturation(ImageAugmentor): ...@@ -218,8 +191,10 @@ class Saturation(ImageAugmentor):
return 1 + self._rand_range(-self.alpha, self.alpha) return 1 + self._rand_range(-self.alpha, self.alpha)
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype
grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img * v + (grey * (1 - v))[:, :, np.newaxis] ret = img * v + (grey * (1 - v))[:, :, np.newaxis]
return ret.astype(old_dtype)
class Lighting(ImageAugmentor): class Lighting(ImageAugmentor):
...@@ -248,8 +223,9 @@ class Lighting(ImageAugmentor): ...@@ -248,8 +223,9 @@ class Lighting(ImageAugmentor):
return self.rng.randn(3) * self.std return self.rng.randn(3) * self.std
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype
v = v * self.eigval v = v * self.eigval
v = v.reshape((3, 1)) v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,)) inc = np.dot(self.eigvec, v).reshape((3,))
img += inc img += inc
return img return img.astype(old_dtype)
...@@ -26,7 +26,7 @@ class JpegNoise(ImageAugmentor): ...@@ -26,7 +26,7 @@ class JpegNoise(ImageAugmentor):
def _augment(self, img, q): def _augment(self, img, q):
enc = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, q])[1] enc = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, q])[1]
return cv2.imdecode(enc, 1) return cv2.imdecode(enc, 1).astype(img.dtype)
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
...@@ -46,10 +46,11 @@ class GaussianNoise(ImageAugmentor): ...@@ -46,10 +46,11 @@ class GaussianNoise(ImageAugmentor):
return self.rng.randn(*img.shape) return self.rng.randn(*img.shape)
def _augment(self, img, noise): def _augment(self, img, noise):
old_dtype = img.dtype
ret = img + noise * self.sigma ret = img + noise * self.sigma
if self.clip: if self.clip or old_dtype == np.uint8:
ret = np.clip(ret, 0, 255) ret = np.clip(ret, 0, 255)
return ret return ret.astype(old_dtype)
class SaltPepperNoise(ImageAugmentor): class SaltPepperNoise(ImageAugmentor):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment