Commit 1e71e8f9 authored by Yuxin Wu's avatar Yuxin Wu

augment images together

parent ac6e140f
......@@ -16,6 +16,8 @@ class DataFlow(object):
def get_data(self):
"""
A generator to generate data as a list.
Datapoint should be a mutable list.
Each component should be assumed immutable.
"""
def size(self):
......
......@@ -183,8 +183,7 @@ class MapDataComponent(ProxyDataFlow):
for dp in self.ds.get_data():
repl = self.func(dp[self.index])
if repl is not None:
dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = repl
dp[self.index] = repl # NOTE modifying
yield dp
class RandomChooseData(DataFlow):
......
......@@ -98,7 +98,7 @@ class Cifar10(DataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
yield self.data[k]
yield copy.copy(self.data[k])
def get_per_pixel_mean(self):
"""
......
......@@ -6,10 +6,10 @@ import numpy as np
import cv2
import copy
from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent
from .common import MapDataComponent, MapData
from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent']
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImagesTogether']
class ImageFromFile(DataFlow):
""" Generate rgb images from list of files """
......@@ -56,3 +56,26 @@ class AugmentImageComponent(MapDataComponent):
self.augs.reset_state()
class AugmentImagesTogether(MapData):
def __init__(self, ds, augmentors, index=(0,1)):
"""
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: tuple of indices of the image components
"""
self.augs = AugmentorList(augmentors)
self.ds = ds
def func(dp):
im = dp[index[0]]
im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im
for idx in index[1:]:
dp[idx] = self.augs._augment(dp[idx], prms)
return dp
super(AugmentImagesTogether, self).__init__(ds, func)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
......@@ -8,7 +8,8 @@ import numpy
from ._common import *
from ..tfutils.symbolic_functions import *
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling']
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
......
......@@ -65,7 +65,7 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
:returns: class-balanced binary classification cross entropy loss
"""
z = batch_flatten(pred)
y = batch_flatten(label)
y = tf.cast(batch_flatten(label), tf.float32)
count_neg = tf.reduce_sum(1. - y)
count_pos = tf.reduce_sum(y)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment