Commit a0601fb7 authored by eyaler's avatar eyaler Committed by Yuxin Wu

enable image exceptions and make this default behavior (#490)

* enable image exceptions and make this default behavior

* change allow_exceptions to catch_exceptions; use contextmanager

* fix linter and update docs
parent a347aff8
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import numpy as np import numpy as np
import copy as copy_mod import copy as copy_mod
from contextlib import contextmanager
from .base import RNGDataFlow from .base import RNGDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from ..utils import logger from ..utils import logger
...@@ -18,6 +19,26 @@ def _valid_coords(coords): ...@@ -18,6 +19,26 @@ def _valid_coords(coords):
assert np.issubdtype(coords.dtype, np.float), coords.dtype assert np.issubdtype(coords.dtype, np.float), coords.dtype
class ExceptionHandler:
def __init__(self, catch_exceptions=False):
self._nr_error = 0
self.catch_exceptions = catch_exceptions
@contextmanager
def catch(self):
try:
yield
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if not self.catch_exceptions:
raise
else:
if self._nr_error % 100 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
class ImageFromFile(RNGDataFlow): class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """ """ Produce images read from a list of files. """
def __init__(self, files, channel=3, resize=None, shuffle=False): def __init__(self, files, channel=3, resize=None, shuffle=False):
...@@ -57,7 +78,8 @@ class AugmentImageComponent(MapDataComponent): ...@@ -57,7 +78,8 @@ class AugmentImageComponent(MapDataComponent):
""" """
Apply image augmentors on 1 image component. Apply image augmentors on 1 image component.
""" """
def __init__(self, ds, augmentors, index=0, copy=True):
def __init__(self, ds, augmentors, index=0, copy=True, catch_exceptions=False):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
...@@ -67,27 +89,22 @@ class AugmentImageComponent(MapDataComponent): ...@@ -67,27 +89,22 @@ class AugmentImageComponent(MapDataComponent):
True, a copy will be made before any augmentors are applied, True, a copy will be made before any augmentors are applied,
to keep the original images not modified. to keep the original images not modified.
Turn it off to save time when you know it's OK. Turn it off to save time when you know it's OK.
catch_exceptions (bool): when set to True, will catch
all exceptions and only warn you when there are too many (>100).
Can be used to ignore occasion errors in data.
""" """
if isinstance(augmentors, AugmentorList): if isinstance(augmentors, AugmentorList):
self.augs = augmentors self.augs = augmentors
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self._nr_error = 0 exception_handler = ExceptionHandler(catch_exceptions)
def func(x): def func(x):
try: with exception_handler.catch():
if copy: if copy:
x = copy_mod.deepcopy(x) x = copy_mod.deepcopy(x)
ret = self.augs.augment(x) return self.augs.augment(x)
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
return ret
super(AugmentImageComponent, self).__init__( super(AugmentImageComponent, self).__init__(
ds, func, index) ds, func, index)
...@@ -102,26 +119,26 @@ class AugmentImageCoordinates(MapData): ...@@ -102,26 +119,26 @@ class AugmentImageCoordinates(MapData):
Apply image augmentors on an image and a list of coordinates. Apply image augmentors on an image and a list of coordinates.
Coordinates must be a Nx2 floating point array, each row is (x, y). Coordinates must be a Nx2 floating point array, each row is (x, y).
""" """
def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True):
def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int): the index of the image component to be augmented. img_index (int): the index of the image component to be augmented.
coords_index (int): the index of the coordinate component to be augmented. coords_index (int): the index of the coordinate component to be augmented.
copy (bool): Some augmentors modify the input images. When copy is copy, catch_exceptions: same as in :class:`AugmentImageComponent`
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
""" """
if isinstance(augmentors, AugmentorList): if isinstance(augmentors, AugmentorList):
self.augs = augmentors self.augs = augmentors
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self._nr_error = 0
exception_handler = ExceptionHandler(catch_exceptions)
def func(dp): def func(dp):
try: with exception_handler.catch():
img, coords = dp[img_index], dp[coords_index] img, coords = dp[img_index], dp[coords_index]
_valid_coords(coords) _valid_coords(coords)
if copy: if copy:
...@@ -131,13 +148,6 @@ class AugmentImageCoordinates(MapData): ...@@ -131,13 +148,6 @@ class AugmentImageCoordinates(MapData):
coords = self.augs._augment_coords(coords, prms) coords = self.augs._augment_coords(coords, prms)
dp[coords_index] = coords dp[coords_index] = coords
return dp return dp
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
super(AugmentImageCoordinates, self).__init__(ds, func) super(AugmentImageCoordinates, self).__init__(ds, func)
...@@ -161,29 +171,27 @@ class AugmentImageComponents(MapData): ...@@ -161,29 +171,27 @@ class AugmentImageComponents(MapData):
""" """
def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True): def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of the image components. index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components. coords_index: tuple of indices of the coordinates components.
copy (bool): Some augmentors modify the input images. When copy is copy, catch_exceptions: same as in :class:`AugmentImageComponent`
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
""" """
if isinstance(augmentors, AugmentorList): if isinstance(augmentors, AugmentorList):
self.augs = augmentors self.augs = augmentors
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.ds = ds self.ds = ds
self._nr_error = 0
exception_handler = ExceptionHandler(catch_exceptions)
def func(dp): def func(dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
try: with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design? major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image]) im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im) im, prms = self.augs._augment_return_params(im)
...@@ -195,13 +203,6 @@ class AugmentImageComponents(MapData): ...@@ -195,13 +203,6 @@ class AugmentImageComponents(MapData):
_valid_coords(coords) _valid_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms) dp[idx] = self.augs._augment_coords(coords, prms)
return dp return dp
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
super(AugmentImageComponents, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment