Commit 27ea2836 authored by Yuxin Wu's avatar Yuxin Wu

doc for prefetch. brightnessadd -> brightness

parent 3e876599
...@@ -131,7 +131,7 @@ def get_data(train_or_test): ...@@ -131,7 +131,7 @@ def get_data(train_or_test):
imgaug.CenterPaste((40, 40)), imgaug.CenterPaste((40, 40)),
imgaug.RandomCrop((32, 32)), imgaug.RandomCrop((32, 32)),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
#imgaug.BrightnessAdd(20), #imgaug.Brightness(20),
#imgaug.Contrast((0.6,1.4)), #imgaug.Contrast((0.6,1.4)),
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
] ]
......
...@@ -134,7 +134,7 @@ def get_data(train_or_test): ...@@ -134,7 +134,7 @@ def get_data(train_or_test):
imgaug.CenterPaste((40, 40)), imgaug.CenterPaste((40, 40)),
imgaug.RandomCrop((32, 32)), imgaug.RandomCrop((32, 32)),
#imgaug.Flip(horiz=True), #imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(10), imgaug.Brightness(10),
imgaug.Contrast((0.8,1.2)), imgaug.Contrast((0.8,1.2)),
imgaug.GaussianDeform( # this is slow imgaug.GaussianDeform( # this is slow
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)], [(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
......
...@@ -82,7 +82,7 @@ def get_data(train_or_test): ...@@ -82,7 +82,7 @@ def get_data(train_or_test):
augmentors = [ augmentors = [
imgaug.RandomCrop((30, 30)), imgaug.RandomCrop((30, 30)),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(63), imgaug.Brightness(63),
imgaug.Contrast((0.2,1.8)), imgaug.Contrast((0.2,1.8)),
imgaug.GaussianDeform( imgaug.GaussianDeform(
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)], [(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
......
...@@ -76,7 +76,7 @@ def get_config(): ...@@ -76,7 +76,7 @@ def get_config():
augmentors = [ augmentors = [
imgaug.Resize((40, 40)), imgaug.Resize((40, 40)),
imgaug.BrightnessAdd(30), imgaug.Brightness(30),
imgaug.Contrast((0.5,1.5)), imgaug.Contrast((0.5,1.5)),
imgaug.GaussianDeform( # this is slow imgaug.GaussianDeform( # this is slow
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)], [(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
......
...@@ -26,7 +26,10 @@ class DataFlow(object): ...@@ -26,7 +26,10 @@ class DataFlow(object):
def reset_state(self): def reset_state(self):
""" """
Reset state of the dataflow (usually the random seed) Reset state of the dataflow,
for example, RNG **HAS** to be reset here if used in the DataFlow.
Otherwise it may not work well with prefetching, because different
processes will have the same RNG state.
""" """
pass pass
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
from .base import ImageAugmentor from .base import ImageAugmentor
import numpy as np import numpy as np
__all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize'] __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize']
class BrightnessAdd(ImageAugmentor): class Brightness(ImageAugmentor):
""" """
Random adjust brightness. Random adjust brightness.
""" """
......
...@@ -20,6 +20,7 @@ class PrefetchProcess(multiprocessing.Process): ...@@ -20,6 +20,7 @@ class PrefetchProcess(multiprocessing.Process):
self.queue = queue self.queue = queue
def run(self): def run(self):
# reset RNG of ds so each process will produce different data
self.ds.reset_state() self.ds.reset_state()
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment