Commit 22221fd8 authored by Yuxin Wu's avatar Yuxin Wu

bugfix in ResizeShortestEdge.

parent 2bbe988b
......@@ -5,6 +5,7 @@
from __future__ import division
import numpy as np
from copy import copy
import pprint
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
......@@ -136,8 +137,11 @@ class BatchData(ProxyDataFlow):
np.asarray([x[k] for x in data_holder], dtype=tp))
except KeyboardInterrupt:
raise
except:
except Exception as e: # noqa
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
if isinstance(dt, np.ndarray):
s = pprint.pformat([x[k].shape for x in data_holder])
logger.error("Shape of all arrays to be batched: " + s)
try:
# open an ipython shell if possible
import IPython as IP; IP.embed() # noqa
......
......@@ -101,13 +101,16 @@ class ResizeShortestEdge(ImageAugmentor):
Args:
size (int): the size to resize the shortest edge to.
"""
size = size * 1.0
size = int(size)
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
scale = self.size / min(h, w)
newh, neww = map(int, [scale * h, scale * w])
scale = self.size * 1.0 / min(h, w)
if h < w:
newh, neww = self.size, int(scale * w)
else:
newh, neww = int(scale * h), self.size
return (h, w, newh, neww)
def _augment(self, img, param):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment