Commit ea72115e authored by Yuxin Wu

hack shape for batch normalization

parent 9952c6c6
......@@ -8,6 +8,7 @@ from .base import DataFlow
from .imgaug import AugmentorList, Image
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData',
'AugmentImageComponent']
class BatchData(DataFlow):
......@@ -124,6 +125,19 @@ class FakeData(DataFlow):
yield [np.random.random(k) for k in self.shapes]
class MapData(DataFlow):
    """Dataflow that yields ``func(dp)`` for each datapoint ``dp`` of ``ds``."""

    def __init__(self, ds, func):
        """
        ds: the dataflow whose datapoints are transformed
        func: callable applied to every datapoint; its return value is
            what this dataflow yields
        """
        self.func = func
        self.ds = ds

    def size(self):
        # One output is produced per input datapoint, so the size is
        # delegated to the wrapped dataflow.
        wrapped = self.ds
        return wrapped.size()

    def get_data(self):
        # Lazily pull datapoints from the wrapped dataflow and emit the
        # mapped result of each one.
        for datapoint in self.ds.get_data():
            mapped = self.func(datapoint)
            yield mapped
class MapDataComponent(DataFlow):
""" Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0):
self.ds = ds
......@@ -138,6 +152,31 @@ class MapData(DataFlow):
dp[self.index] = self.func(dp[self.index])
yield dp
class RandomChooseData(DataFlow):
    """
    Randomly choose from several dataflow. Stop producing when any of its
    dataflow stops.
    """

    def __init__(self, df_lists):
        """
        df_lists: list of dataflow, or list of (dataflow, probability) tuple.
            When probabilities are given they must sum to 1 (up to
            floating-point rounding); otherwise every dataflow is chosen
            with equal probability.
        """
        if isinstance(df_lists[0], (tuple, list)):
            total = sum([v[1] for v in df_lists])
            # Compare with a tolerance: an exact `== 1.0` rejects valid
            # inputs such as ten weights of 0.1, whose float sum is
            # 0.9999999999999999.
            assert abs(total - 1.0) < 1e-6, \
                "probabilities must sum to 1, got {}".format(total)
            self.df_lists = df_lists
        else:
            prob = 1.0 / len(df_lists)
            self.df_lists = [(k, prob) for k in df_lists]

    def get_data(self):
        """
        Yield datapoints, each drawn from a randomly selected dataflow.
        Ends as soon as the selected dataflow has no more data.
        """
        # iter() is a no-op for generators and lets a dataflow return any
        # plain iterable from get_data().
        itrs = [iter(v[0].get_data()) for v in self.df_lists]
        probs = np.array([v[1] for v in self.df_lists], dtype='float64')
        # Renormalize so rounding error accepted by the constructor is not
        # rejected by numpy's own "probabilities do not sum to 1" check.
        probs = probs / probs.sum()
        try:
            while True:
                # Sample an index rather than the iterator objects
                # themselves, so numpy never has to coerce the iterators
                # into an array.
                idx = np.random.choice(len(itrs), p=probs)
                yield next(itrs[idx])
        except StopIteration:
            return
def AugmentImageComponent(ds, augmentors, index=0):
"""
Augment the image in each data point
......@@ -146,9 +185,9 @@ def AugmentImageComponent(ds, augmentors, index=0):
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
# TODO reset rng at the beginning of each get_data
# TODO reset rng at the beginning of each get_data
aug = AugmentorList(augmentors)
return MapData(
return MapDataComponent(
ds,
lambda img: aug.augment(Image(img)).arr,
index)
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from copy import copy
from ._common import layer_register
......@@ -37,7 +38,16 @@ def BatchNorm(x, is_training, gamma_init=1.0):
beta = tf.get_variable('beta', [n_out])
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(gamma_init))
# XXX hack to clear shape. see tensorflow#1162
if shape[0] is not None:
x = tf.tile(x, tf.pack([1,1,1,1]))
hack_shape = copy(shape)
hack_shape[0] = None
x.set_shape(hack_shape)
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
print batch_mean
ema = tf.train.ExponentialMovingAverage(decay=0.999)
ema_apply_op = ema.apply([batch_mean, batch_var])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment