Commit c44150d8 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 97cce4a2
......@@ -6,7 +6,6 @@ import tensorflow as tf
import argparse
import os
from tensorpack import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
......
......@@ -13,7 +13,13 @@ from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
def _valid_coords(coords):
def check_dtype(img):
if isinstance(img.dtype, np.integer):
assert img.dtype == np.uint8, \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
def validate_coords(coords):
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
......@@ -99,6 +105,7 @@ class AugmentImageComponent(MapDataComponent):
exception_handler = ExceptionHandler(catch_exceptions)
def func(x):
check_dtype(x)
with exception_handler.catch():
if copy:
x = copy_mod.deepcopy(x)
......@@ -138,7 +145,8 @@ class AugmentImageCoordinates(MapData):
def func(dp):
with exception_handler.catch():
img, coords = dp[img_index], dp[coords_index]
_valid_coords(coords)
check_dtype(img)
validate_coords(coords)
if copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img)
......@@ -191,14 +199,16 @@ class AugmentImageComponents(MapData):
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design?
check_dtype(major_image)
im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im)
dp[major_image] = im
for idx in index[1:]:
check_dtype(dp[idx])
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
for idx in coords_index:
coords = copy_func(dp[idx])
_valid_coords(coords)
validate_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms)
return dp
......
......@@ -5,10 +5,12 @@
import inspect
import pprint
from abc import abstractmethod, ABCMeta
from ...utils.utils import get_rng
import six
from six.moves import zip
from ...utils.utils import get_rng
from ..image import check_dtype
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
......@@ -101,6 +103,10 @@ class Augmentor(object):
class ImageAugmentor(Augmentor):
"""
ImageAugmentor should take images of type uint8 in range [0, 255], or
floating point images in range [0, 1] or [0, 255].
"""
def augment_coords(self, coords, param):
return self._augment_coords(coords, param)
......@@ -137,6 +143,7 @@ class AugmentorList(ImageAugmentor):
raise RuntimeError("Cannot simply get all parameters of a AugmentorList without running the augmentation!")
def _augment_return_params(self, img):
check_dtype(img)
assert img.ndim in [2, 3], img.ndim
prms = []
......@@ -146,6 +153,7 @@ class AugmentorList(ImageAugmentor):
return img, prms
def _augment(self, img, param):
check_dtype(img)
assert img.ndim in [2, 3], img.ndim
for aug, prm in zip(self.augs, param):
img = aug._augment(img, prm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment