Commit 2d923b91 authored by Yuxin Wu's avatar Yuxin Wu

update augmentors

parent cdd71bfe
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import copy
from .base import DataFlow from .base import DataFlow
from .imgaug import AugmentorList, Image from .imgaug import AugmentorList, Image
...@@ -149,6 +150,7 @@ class MapDataComponent(DataFlow): ...@@ -149,6 +150,7 @@ class MapDataComponent(DataFlow):
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = self.func(dp[self.index]) dp[self.index] = self.func(dp[self.index])
yield dp yield dp
......
...@@ -6,6 +6,7 @@ import os, sys ...@@ -6,6 +6,7 @@ import os, sys
import pickle import pickle
import numpy as np import numpy as np
from six.moves import urllib from six.moves import urllib
import copy
import tarfile import tarfile
import logging import logging
...@@ -78,7 +79,7 @@ class Cifar10(DataFlow): ...@@ -78,7 +79,7 @@ class Cifar10(DataFlow):
if train_or_test == 'train': if train_or_test == 'train':
self.fs = fnames[:5] self.fs = fnames[:5]
else: else:
self.fs = fnames[-1] self.fs = [fnames[-1]]
for f in self.fs: for f in self.fs:
if not os.path.isfile(f): if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f) raise ValueError('Failed to find file: ' + f)
......
...@@ -10,7 +10,7 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize'] ...@@ -10,7 +10,7 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
class BrightnessAdd(ImageAugmentor): class BrightnessAdd(ImageAugmentor):
""" """
Randomly add a value within [-delta,delta], and clip in [0,1] Randomly add a value within [-delta,delta], and clip in [0,255]
""" """
def __init__(self, delta): def __init__(self, delta):
assert delta > 0 assert delta > 0
...@@ -24,6 +24,7 @@ class BrightnessAdd(ImageAugmentor): ...@@ -24,6 +24,7 @@ class BrightnessAdd(ImageAugmentor):
class Contrast(ImageAugmentor): class Contrast(ImageAugmentor):
""" """
Apply x = (x - mean) * contrast_factor + mean to each channel Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
""" """
def __init__(self, factor_range): def __init__(self, factor_range):
self._init(locals()) self._init(locals())
......
...@@ -7,7 +7,7 @@ from .base import ImageAugmentor ...@@ -7,7 +7,7 @@ from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
__all__ = ['Flip'] __all__ = ['Flip', 'MapImage']
class Flip(ImageAugmentor): class Flip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5): def __init__(self, horiz=False, vert=False, prob=0.5):
...@@ -34,3 +34,9 @@ class Flip(ImageAugmentor): ...@@ -34,3 +34,9 @@ class Flip(ImageAugmentor):
raise NotImplementedError() raise NotImplementedError()
class MapImage(ImageAugmentor):
def __init__(self, func):
self.func = func
def _augment(self, img):
img.arr = self.func(img.arr)
...@@ -62,13 +62,13 @@ class Trainer(object): ...@@ -62,13 +62,13 @@ class Trainer(object):
self.summary_writer.add_summary(summary, self.global_step) self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self): def main_loop(self):
self._init_summary()
callbacks = self.config.callbacks callbacks = self.config.callbacks
callbacks.before_train(self)
with self.sess.as_default(): with self.sess.as_default():
try: try:
self._init_summary()
self.global_step = get_global_step() self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step)) logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train(self)
tf.get_default_graph().finalize() tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch): for epoch in xrange(1, self.config.max_epoch):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment