Commit 1e71e8f9 authored by Yuxin Wu's avatar Yuxin Wu

augment images together

parent ac6e140f
...@@ -16,6 +16,8 @@ class DataFlow(object): ...@@ -16,6 +16,8 @@ class DataFlow(object):
def get_data(self): def get_data(self):
""" """
A generator to generate data as a list. A generator to generate data as a list.
Datapoint should be a mutable list.
Each component should be assumed immutable.
""" """
def size(self): def size(self):
......
...@@ -183,8 +183,7 @@ class MapDataComponent(ProxyDataFlow): ...@@ -183,8 +183,7 @@ class MapDataComponent(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
repl = self.func(dp[self.index]) repl = self.func(dp[self.index])
if repl is not None: if repl is not None:
dp = copy.deepcopy(dp) # avoid modifying the original dp dp[self.index] = repl # NOTE modifying
dp[self.index] = repl
yield dp yield dp
class RandomChooseData(DataFlow): class RandomChooseData(DataFlow):
......
...@@ -98,7 +98,7 @@ class Cifar10(DataFlow): ...@@ -98,7 +98,7 @@ class Cifar10(DataFlow):
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
yield self.data[k] yield copy.copy(self.data[k])
def get_per_pixel_mean(self): def get_per_pixel_mean(self):
""" """
......
...@@ -6,10 +6,10 @@ import numpy as np ...@@ -6,10 +6,10 @@ import numpy as np
import cv2 import cv2
import copy import copy
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent from .common import MapDataComponent, MapData
from .imgaug import AugmentorList from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImagesTogether']
class ImageFromFile(DataFlow): class ImageFromFile(DataFlow):
""" Generate rgb images from list of files """ """ Generate rgb images from list of files """
...@@ -56,3 +56,26 @@ class AugmentImageComponent(MapDataComponent): ...@@ -56,3 +56,26 @@ class AugmentImageComponent(MapDataComponent):
self.augs.reset_state() self.augs.reset_state()
class AugmentImagesTogether(MapData):
def __init__(self, ds, augmentors, index=(0,1)):
"""
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: tuple of indices of the image components
"""
self.augs = AugmentorList(augmentors)
self.ds = ds
def func(dp):
im = dp[index[0]]
im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im
for idx in index[1:]:
dp[idx] = self.augs._augment(dp[idx], prms)
return dp
super(AugmentImagesTogether, self).__init__(ds, func)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
...@@ -8,7 +8,8 @@ import numpy ...@@ -8,7 +8,8 @@ import numpy
from ._common import * from ._common import *
from ..tfutils.symbolic_functions import * from ..tfutils.symbolic_functions import *
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling'] __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
@layer_register() @layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'): def MaxPooling(x, shape, stride=None, padding='VALID'):
......
...@@ -65,7 +65,7 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l ...@@ -65,7 +65,7 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
:returns: class-balanced binary classification cross entropy loss :returns: class-balanced binary classification cross entropy loss
""" """
z = batch_flatten(pred) z = batch_flatten(pred)
y = batch_flatten(label) y = tf.cast(batch_flatten(label), tf.float32)
count_neg = tf.reduce_sum(1. - y) count_neg = tf.reduce_sum(1. - y)
count_pos = tf.reduce_sum(y) count_pos = tf.reduce_sum(y)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment