Commit e2261920 authored by Yuxin Wu's avatar Yuxin Wu

fix linearwrap + tf function

parent 31236d84
...@@ -155,6 +155,8 @@ def split_input(img): ...@@ -155,6 +155,8 @@ def split_input(img):
def colorization_input(img): def colorization_input(img):
assert img.ndim == 3 assert img.ndim == 3
if min(img.shape[:2]) < SHAPE:
return None # skip the image
# create gray + RGB pairs # create gray + RGB pairs
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis] gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
return [gray, img] return [gray, img]
...@@ -162,7 +164,6 @@ def colorization_input(img): ...@@ -162,7 +164,6 @@ def colorization_input(img):
def get_data(): def get_data():
datadir = args.data datadir = args.data
# assume each image is 512x256 split to left and right
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
......
...@@ -7,6 +7,7 @@ import cv2 ...@@ -7,6 +7,7 @@ import cv2
from .base import RNGDataFlow from .base import RNGDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from .imgaug import AugmentorList from .imgaug import AugmentorList
from ..utils import logger
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
...@@ -81,14 +82,23 @@ class AugmentImageComponents(MapData): ...@@ -81,14 +82,23 @@ class AugmentImageComponents(MapData):
""" """
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.ds = ds self.ds = ds
self._nr_error = 0
def func(dp): def func(dp):
im = dp[index[0]] try:
im, prms = self.augs._augment_return_params(im) im = dp[index[0]]
dp[index[0]] = im im, prms = self.augs._augment_return_params(im)
for idx in index[1:]: dp[index[0]] = im
dp[idx] = self.augs._augment(dp[idx], prms) for idx in index[1:]:
return dp dp[idx] = self.augs._augment(dp[idx], prms)
return dp
except KeyboardInterrupt:
raise
except Exception:
self._nr_error += 1
if self._nr_error % 1000 == 0:
logger.warn("Got {} augmentation errors.".format(self._nr_error))
return None
super(AugmentImageComponents, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
......
...@@ -60,7 +60,9 @@ class LinearWrap(object): ...@@ -60,7 +60,9 @@ class LinearWrap(object):
logger.warn( logger.warn(
"You're calling LinearWrap.__getattr__ with {}:" "You're calling LinearWrap.__getattr__ with {}:"
" neither a layer nor 'tf'!".format(layer_name)) " neither a layer nor 'tf'!".format(layer_name))
assert isinstance(layer, ModuleType) import tensorflow as tf # noqa
layer = eval(layer_name)
assert isinstance(layer, ModuleType), layer
return LinearWrap._TFModuleFunc(layer, self._t) return LinearWrap._TFModuleFunc(layer, self._t)
def apply(self, func, *args, **kwargs): def apply(self, func, *args, **kwargs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment