Commit c44150d8 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 97cce4a2
...@@ -6,7 +6,6 @@ import tensorflow as tf ...@@ -6,7 +6,6 @@ import tensorflow as tf
import argparse import argparse
import os import os
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
......
...@@ -13,7 +13,13 @@ from ..utils.argtools import shape2d ...@@ -13,7 +13,13 @@ from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
def _valid_coords(coords): def check_dtype(img):
if isinstance(img.dtype, np.integer):
assert img.dtype == np.uint8, \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
def validate_coords(coords):
assert coords.ndim == 2, coords.ndim assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype assert np.issubdtype(coords.dtype, np.float), coords.dtype
...@@ -99,6 +105,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -99,6 +105,7 @@ class AugmentImageComponent(MapDataComponent):
exception_handler = ExceptionHandler(catch_exceptions) exception_handler = ExceptionHandler(catch_exceptions)
def func(x): def func(x):
check_dtype(x)
with exception_handler.catch(): with exception_handler.catch():
if copy: if copy:
x = copy_mod.deepcopy(x) x = copy_mod.deepcopy(x)
...@@ -138,7 +145,8 @@ class AugmentImageCoordinates(MapData): ...@@ -138,7 +145,8 @@ class AugmentImageCoordinates(MapData):
def func(dp): def func(dp):
with exception_handler.catch(): with exception_handler.catch():
img, coords = dp[img_index], dp[coords_index] img, coords = dp[img_index], dp[coords_index]
_valid_coords(coords) check_dtype(img)
validate_coords(coords)
if copy: if copy:
img, coords = copy_mod.deepcopy((img, coords)) img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img) img, prms = self.augs._augment_return_params(img)
...@@ -191,14 +199,16 @@ class AugmentImageComponents(MapData): ...@@ -191,14 +199,16 @@ class AugmentImageComponents(MapData):
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch(): with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design? major_image = index[0] # image to be used to get params. TODO better design?
check_dtype(major_image)
im = copy_func(dp[major_image]) im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im) im, prms = self.augs._augment_return_params(im)
dp[major_image] = im dp[major_image] = im
for idx in index[1:]: for idx in index[1:]:
check_dtype(dp[idx])
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms) dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
for idx in coords_index: for idx in coords_index:
coords = copy_func(dp[idx]) coords = copy_func(dp[idx])
_valid_coords(coords) validate_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms) dp[idx] = self.augs._augment_coords(coords, prms)
return dp return dp
......
...@@ -5,10 +5,12 @@ ...@@ -5,10 +5,12 @@
import inspect import inspect
import pprint import pprint
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ...utils.utils import get_rng
import six import six
from six.moves import zip from six.moves import zip
from ...utils.utils import get_rng
from ..image import check_dtype
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList'] __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
...@@ -101,6 +103,10 @@ class Augmentor(object): ...@@ -101,6 +103,10 @@ class Augmentor(object):
class ImageAugmentor(Augmentor): class ImageAugmentor(Augmentor):
"""
ImageAugmentor should take images of type uint8 in range [0, 255], or
floating point images in range [0, 1] or [0, 255].
"""
def augment_coords(self, coords, param): def augment_coords(self, coords, param):
return self._augment_coords(coords, param) return self._augment_coords(coords, param)
...@@ -137,6 +143,7 @@ class AugmentorList(ImageAugmentor): ...@@ -137,6 +143,7 @@ class AugmentorList(ImageAugmentor):
raise RuntimeError("Cannot simply get all parameters of a AugmentorList without running the augmentation!") raise RuntimeError("Cannot simply get all parameters of a AugmentorList without running the augmentation!")
def _augment_return_params(self, img): def _augment_return_params(self, img):
check_dtype(img)
assert img.ndim in [2, 3], img.ndim assert img.ndim in [2, 3], img.ndim
prms = [] prms = []
...@@ -146,6 +153,7 @@ class AugmentorList(ImageAugmentor): ...@@ -146,6 +153,7 @@ class AugmentorList(ImageAugmentor):
return img, prms return img, prms
def _augment(self, img, param): def _augment(self, img, param):
check_dtype(img)
assert img.ndim in [2, 3], img.ndim assert img.ndim in [2, 3], img.ndim
for aug, prm in zip(self.augs, param): for aug, prm in zip(self.augs, param):
img = aug._augment(img, prm) img = aug._augment(img, prm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment