Commit ea72115e authored by Yuxin Wu's avatar Yuxin Wu

hack shape for batch normalization

parent 9952c6c6
...@@ -8,6 +8,7 @@ from .base import DataFlow ...@@ -8,6 +8,7 @@ from .base import DataFlow
from .imgaug import AugmentorList, Image from .imgaug import AugmentorList, Image
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData',
'AugmentImageComponent'] 'AugmentImageComponent']
class BatchData(DataFlow): class BatchData(DataFlow):
...@@ -124,6 +125,19 @@ class FakeData(DataFlow): ...@@ -124,6 +125,19 @@ class FakeData(DataFlow):
yield [np.random.random(k) for k in self.shapes] yield [np.random.random(k) for k in self.shapes]
class MapData(DataFlow): class MapData(DataFlow):
""" Map a function to the datapoint"""
def __init__(self, ds, func):
self.ds = ds
self.func = func
def size(self):
return self.ds.size()
def get_data(self):
for dp in self.ds.get_data():
yield self.func(dp)
class MapDataComponent(DataFlow):
""" Apply a function to the given index in the datapoint""" """ Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
self.ds = ds self.ds = ds
...@@ -138,6 +152,31 @@ class MapData(DataFlow): ...@@ -138,6 +152,31 @@ class MapData(DataFlow):
dp[self.index] = self.func(dp[self.index]) dp[self.index] = self.func(dp[self.index])
yield dp yield dp
class RandomChooseData(DataFlow):
"""
Randomly choose from several dataflow. Stop producing when any of its dataflow stops.
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow, or list of (dataflow, probability) tuple
"""
if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0
self.df_lists = df_lists
else:
prob = 1.0 / len(df_lists)
self.df_lists = [(k, prob) for k in df_lists]
def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists]
probs = np.array([v[1] for v in self.df_lists])
try:
while True:
itr = np.random.choice(itrs, p=probs)
yield next(itr)
except StopIteration:
return
def AugmentImageComponent(ds, augmentors, index=0): def AugmentImageComponent(ds, augmentors, index=0):
""" """
Augment the image in each data point Augment the image in each data point
...@@ -146,9 +185,9 @@ def AugmentImageComponent(ds, augmentors, index=0): ...@@ -146,9 +185,9 @@ def AugmentImageComponent(ds, augmentors, index=0):
augmentors: a list of ImageAugmentor instance augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0 index: the index of image in each data point. default to be 0
""" """
# TODO reset rng at the beginning of each get_data # TODO reset rng at the beginning of each get_data
aug = AugmentorList(augmentors) aug = AugmentorList(augmentors)
return MapData( return MapDataComponent(
ds, ds,
lambda img: aug.augment(Image(img)).arr, lambda img: aug.augment(Image(img)).arr,
index) index)
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from copy import copy
from ._common import layer_register from ._common import layer_register
...@@ -37,7 +38,16 @@ def BatchNorm(x, is_training, gamma_init=1.0): ...@@ -37,7 +38,16 @@ def BatchNorm(x, is_training, gamma_init=1.0):
beta = tf.get_variable('beta', [n_out]) beta = tf.get_variable('beta', [n_out])
gamma = tf.get_variable('gamma', [n_out], gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(gamma_init)) initializer=tf.constant_initializer(gamma_init))
# XXX hack to clear shape. see tensorflow#1162
if shape[0] is not None:
x = tf.tile(x, tf.pack([1,1,1,1]))
hack_shape = copy(shape)
hack_shape[0] = None
x.set_shape(hack_shape)
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
print batch_mean
ema = tf.train.ExponentialMovingAverage(decay=0.999) ema = tf.train.ExponentialMovingAverage(decay=0.999)
ema_apply_op = ema.apply([batch_mean, batch_var]) ema_apply_op = ema.apply([batch_mean, batch_var])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment