Commit a0601fb7 authored by eyaler's avatar eyaler Committed by Yuxin Wu

enable image exceptions and make this default behavior (#490)

* enable image exceptions and make this default behavior

* change allow_exceptions to catch_exceptions; use contextmanager

* fix linter and update docs
parent a347aff8
......@@ -4,6 +4,7 @@
import numpy as np
import copy as copy_mod
from contextlib import contextmanager
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from ..utils import logger
......@@ -18,6 +19,26 @@ def _valid_coords(coords):
assert np.issubdtype(coords.dtype, np.float), coords.dtype
class ExceptionHandler:
def __init__(self, catch_exceptions=False):
self._nr_error = 0
self.catch_exceptions = catch_exceptions
@contextmanager
def catch(self):
try:
yield
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if not self.catch_exceptions:
raise
else:
if self._nr_error % 100 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """
def __init__(self, files, channel=3, resize=None, shuffle=False):
......@@ -57,7 +78,8 @@ class AugmentImageComponent(MapDataComponent):
"""
Apply image augmentors on 1 image component.
"""
def __init__(self, ds, augmentors, index=0, copy=True):
def __init__(self, ds, augmentors, index=0, copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
......@@ -67,27 +89,22 @@ class AugmentImageComponent(MapDataComponent):
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
catch_exceptions (bool): when set to True, will catch
all exceptions and only warn you when there are too many (>100).
Can be used to ignore occasion errors in data.
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self._nr_error = 0
exception_handler = ExceptionHandler(catch_exceptions)
def func(x):
try:
with exception_handler.catch():
if copy:
x = copy_mod.deepcopy(x)
ret = self.augs.augment(x)
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
return ret
return self.augs.augment(x)
super(AugmentImageComponent, self).__init__(
ds, func, index)
......@@ -102,26 +119,26 @@ class AugmentImageCoordinates(MapData):
Apply image augmentors on an image and a list of coordinates.
Coordinates must be a Nx2 floating point array, each row is (x, y).
"""
def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True):
def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int): the index of the image component to be augmented.
coords_index (int): the index of the coordinate component to be augmented.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self._nr_error = 0
exception_handler = ExceptionHandler(catch_exceptions)
def func(dp):
try:
with exception_handler.catch():
img, coords = dp[img_index], dp[coords_index]
_valid_coords(coords)
if copy:
......@@ -131,13 +148,6 @@ class AugmentImageCoordinates(MapData):
coords = self.augs._augment_coords(coords, prms)
dp[coords_index] = coords
return dp
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
super(AugmentImageCoordinates, self).__init__(ds, func)
......@@ -161,29 +171,27 @@ class AugmentImageComponents(MapData):
"""
def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True):
def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
else:
self.augs = AugmentorList(augmentors)
self.ds = ds
self._nr_error = 0
exception_handler = ExceptionHandler(catch_exceptions)
def func(dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
try:
with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im)
......@@ -195,13 +203,6 @@ class AugmentImageComponents(MapData):
_valid_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms)
return dp
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0 or self._nr_error < 10:
logger.exception("Got {} augmentation errors.".format(self._nr_error))
return None
super(AugmentImageComponents, self).__init__(ds, func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment