Commit 4af48399 authored by Yuxin Wu's avatar Yuxin Wu

fix about box & imgaug

parent d5410902
......@@ -26,6 +26,7 @@ parser.add_argument('--index',
parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int)
args = parser.parse_args()
logger.auto_set_dir(action='d')
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
......
......@@ -68,8 +68,15 @@ class BatchData(ProxyDataFlow):
tp = 'float32'
else:
tp = dt.dtype
result.append(
np.array([x[k] for x in data_holder], dtype=tp))
try:
result.append(
np.array([x[k] for x in data_holder], dtype=tp))
except KeyboardInterrupt:
raise
except:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
return result
class FixedSizeData(ProxyDataFlow):
......
......@@ -145,9 +145,10 @@ class ILSVRC12(RNGDataFlow):
self.synset = meta.get_synset_1000()
if include_bb:
bbdir = os.path.join(dir, 'bbox') if not \
isinstance(include_bb, six.string_types) else include_bb
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(
os.path.join(dir, 'bbox'), self.imglist)
self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
self.include_bb = include_bb
def size(self):
......@@ -156,7 +157,7 @@ class ILSVRC12(RNGDataFlow):
def get_data(self):
"""
Produce original images of shape [h, w, 3], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax] in [0, 1]
and optionally a bbox of [xmin, ymin, xmax, ymax]
"""
idxs = np.arange(len(self.imglist))
add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
......@@ -174,8 +175,8 @@ class ILSVRC12(RNGDataFlow):
im = np.expand_dims(im, 2).repeat(3,2)
if self.include_bb:
bb = self.bblist[k]
if not bb:
bb = [0, 0, 1, 1]
if bb is None:
bb = [0, 0, im.shape[1]-1, im.shape[0]-1]
yield [im, label, bb]
else:
yield [im, label]
......@@ -191,10 +192,10 @@ class ILSVRC12(RNGDataFlow):
box = root.find('object').find('bndbox').getchildren()
box = map(lambda x: float(x.text), box)
box[0] /= size[0]
box[1] /= size[1]
box[2] /= size[0]
box[3] /= size[1]
#box[0] /= size[0]
#box[1] /= size[1]
#box[2] /= size[0]
#box[3] /= size[1]
return np.asarray(box, dtype='float32')
with timed_operation('Loading Bounding Boxes ...'):
......
......@@ -8,7 +8,8 @@ from ...utils.rect import Rect
from six.moves import range
import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', 'RandomCropRandomShape']
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
'RandomCropRandomShape', 'perturb_BB']
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
......
......@@ -98,7 +98,7 @@ class Gamma(ImageAugmentor):
return self._rand_range(*self.range)
def _augment(self, img, gamma):
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = img.astype('uint8')
img = np.clip(img, 0, 255).astype('uint8')
img = cv2.LUT(img, lut).astype('float32')
return img
......@@ -25,6 +25,10 @@ class RandomApplyAug(ImageAugmentor):
else:
return (False, None)
def reset_state(self):
super(RandomApplyAug, self).reset_state()
self.aug.reset_state()
def _augment(self, img, prm):
if not prm[0]:
return img
......@@ -44,6 +48,11 @@ class RandomChooseAug(ImageAugmentor):
prob = 1.0 / len(aug_lists)
self._init(locals())
def reset_state(self):
super(RandomChooseAug, self).reset_state()
for a in self.aug_lists:
a.reset_state()
def _get_augment_params(self, img):
aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob)
aug_prm = self.aug_lists[aug_idx]._get_augment_params(img)
......
......@@ -45,6 +45,7 @@ def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None
def _set_file(path):
if os.path.isfile(path):
backup_name = path + '.' + get_time_str()
......@@ -63,10 +64,11 @@ def set_logger_dir(dirname, action=None):
"""
global LOG_FILE, LOG_DIR
if os.path.isdir(dirname):
_logger.warn("""\
if not action:
_logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory. \
If you're resuming from a previous run you can choose to keep it.""".format(dirname))
_logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
_logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
while not action:
action = input().lower().strip()
act = action
......@@ -97,9 +99,11 @@ def disable_logger():
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
globals()[func] = lambda x: None
def auto_set_dir(action=None):
def auto_set_dir(action=None, overwrite_setting=False):
""" set log directory to a subdir inside 'train_log', with the name being
the main python file currently running"""
if LOG_DIR is not None and not overwrite_setting:
return
mod = sys.modules['__main__']
basename = os.path.basename(mod.__file__)
set_logger_dir(
......
......@@ -67,10 +67,11 @@ class Rect(object):
return True
def roi(self, img):
assert self.validate(img.shape[:2])
assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
return img[self.y0:self.y1+1, self.x0:self.x1+1]
def expand(self, frac):
assert frac > 1.0, frac
neww = self.w * frac
newh = self.h * frac
newx = self.x - (neww - self.w) * 0.5
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment