Commit 2d923b91 authored by Yuxin Wu's avatar Yuxin Wu

update augmentors

parent cdd71bfe
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import copy
from .base import DataFlow
from .imgaug import AugmentorList, Image
......@@ -149,6 +150,7 @@ class MapDataComponent(DataFlow):
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = self.func(dp[self.index])
yield dp
......
......@@ -6,6 +6,7 @@ import os, sys
import pickle
import numpy as np
from six.moves import urllib
import copy
import tarfile
import logging
......@@ -78,7 +79,7 @@ class Cifar10(DataFlow):
if train_or_test == 'train':
self.fs = fnames[:5]
else:
self.fs = fnames[-1]
self.fs = [fnames[-1]]
for f in self.fs:
if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f)
......
......@@ -10,7 +10,7 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
class BrightnessAdd(ImageAugmentor):
"""
Randomly add a value within [-delta,delta], and clip in [0,1]
Randomly add a value within [-delta,delta], and clip in [0,255]
"""
def __init__(self, delta):
assert delta > 0
......@@ -24,6 +24,7 @@ class BrightnessAdd(ImageAugmentor):
class Contrast(ImageAugmentor):
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
"""
def __init__(self, factor_range):
self._init(locals())
......
......@@ -7,7 +7,7 @@ from .base import ImageAugmentor
import numpy as np
import cv2
__all__ = ['Flip']
__all__ = ['Flip', 'MapImage']
class Flip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5):
......@@ -34,3 +34,9 @@ class Flip(ImageAugmentor):
raise NotImplementedError()
class MapImage(ImageAugmentor):
def __init__(self, func):
self.func = func
def _augment(self, img):
img.arr = self.func(img.arr)
......@@ -62,13 +62,13 @@ class Trainer(object):
self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self):
self._init_summary()
callbacks = self.config.callbacks
callbacks.before_train(self)
with self.sess.as_default():
try:
self._init_summary()
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train(self)
tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment