Commit 27134e15 authored by Yuxin Wu's avatar Yuxin Wu

update imgaug

parent 78ce3a96
...@@ -100,7 +100,6 @@ class MapData(DataFlow): ...@@ -100,7 +100,6 @@ class MapData(DataFlow):
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
d = list(dp)
dp[self.index] = self.func(dp[self.index]) dp[self.index] = self.func(dp[self.index])
yield dp yield dp
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: _test.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import cv2
from . import *
augmentors = AugmentorList([
Contrast((0.2,1.8)),
Flip(horiz=True)
])
img = cv2.imread('cat.jpg')
img = Image(img)
augmentors.augment(img)
cv2.imshow(" ", img.arr)
cv2.waitKey()
...@@ -61,6 +61,6 @@ class AugmentorList(ImageAugmentor): ...@@ -61,6 +61,6 @@ class AugmentorList(ImageAugmentor):
def _augment(self, img): def _augment(self, img):
assert img.arr.ndim in [2, 3] assert img.arr.ndim in [2, 3]
img.arr = img.arr.astype('float32') / 255.0 img.arr = img.arr.astype('float32')
for aug in self.augs: for aug in self.augs:
aug.augment(img) aug.augment(img)
...@@ -19,7 +19,7 @@ class BrightnessAdd(ImageAugmentor): ...@@ -19,7 +19,7 @@ class BrightnessAdd(ImageAugmentor):
def _augment(self, img): def _augment(self, img):
v = self._rand_range(-self.delta, self.delta) v = self._rand_range(-self.delta, self.delta)
img.arr += v img.arr += v
img.arr = np.clip(img.arr, 0, 1) img.arr = np.clip(img.arr, 0, 255)
class Contrast(ImageAugmentor): class Contrast(ImageAugmentor):
""" """
...@@ -33,7 +33,7 @@ class Contrast(ImageAugmentor): ...@@ -33,7 +33,7 @@ class Contrast(ImageAugmentor):
r = self._rand_range(*self.factor_range) r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True) mean = np.mean(arr, axis=(0,1), keepdims=True)
img.arr = (arr - mean) * r + mean img.arr = (arr - mean) * r + mean
img.arr = np.clip(img.arr, 0, 1) img.arr = np.clip(img.arr, 0, 255)
class PerImageWhitening(ImageAugmentor): class PerImageWhitening(ImageAugmentor):
""" """
...@@ -41,11 +41,16 @@ class PerImageWhitening(ImageAugmentor): ...@@ -41,11 +41,16 @@ class PerImageWhitening(ImageAugmentor):
x = (x - mean) / adjusted_stddev x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels)) where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
""" """
def __init__(self): def __init__(self, all_channel=True):
self.all_channel = all_channel
pass pass
def _augment(self, img): def _augment(self, img):
mean = np.mean(img.arr, axis=(0,1), keepdims=True) if self.all_channel:
std = np.std(img.arr, axis=(0,1), keepdims=True) mean = np.mean(img.arr)
std = np.std(img.arr)
else:
mean = np.mean(img.arr, axis=(0,1), keepdims=True)
std = np.std(img.arr, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape))) std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape)))
img.arr = (img.arr - mean) / std img.arr = (img.arr - mean) / std
...@@ -28,6 +28,7 @@ class Flip(ImageAugmentor): ...@@ -28,6 +28,7 @@ class Flip(ImageAugmentor):
self._init() self._init()
def _augment(self, img): def _augment(self, img):
# TODO XXX prob is wrong for both mode
if self._rand_range() < self.prob: if self._rand_range() < self.prob:
img.arr = cv2.flip(img.arr, self.code) img.arr = cv2.flip(img.arr, self.code)
if img.coords: if img.coords:
......
...@@ -161,8 +161,9 @@ def start_train(config): ...@@ -161,8 +161,9 @@ def start_train(config):
summary_grads(grads) summary_grads(grads)
avg_maintain_op = summary_moving_average(cost_var) avg_maintain_op = summary_moving_average(cost_var)
with tf.control_dependencies([avg_maintain_op]): train_op = tf.group(
train_op = config.optimizer.apply_gradients(grads, get_global_step_var()) config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
describe_model() describe_model()
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment