Commit 7fe010cb authored by Yuxin Wu

batch norm for fc

parent a19a57b2
@@ -51,6 +51,10 @@ def read_cifar10(filenames):
             yield [img, label[k]]
 
 class Cifar10(DataFlow):
+    """
+    Return [image, label],
+    image is 32x32x3 in the range [0,255]
+    """
     def __init__(self, train_or_test, dir=None):
         """
         Args:
......
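The added docstring pins down Cifar10's datapoint contract: each datapoint is [image, label] with a 32x32x3 image in [0, 255]. A minimal usage sketch (the import path is hypothetical, and it assumes the DataFlow exposes a get_data() generator like the read_cifar10 generator above):

# Sketch only: import path and get_data() interface are assumptions.
from dataflow.dataset import Cifar10

ds = Cifar10('train')
for img, label in ds.get_data():
    # img: 32x32x3 array with values in [0, 255]; label: integer class id
    assert img.shape == (32, 32, 3)
    break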
@@ -96,6 +96,10 @@ class DataSet(object):
         return self._num_examples
 
 class Mnist(DataFlow):
+    """
+    Return [image, label],
+    image is 28x28 in the range [0,1]
+    """
     def __init__(self, train_or_test, dir=None):
         """
         Args:
......
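Note the deliberate contrast with Cifar10: Mnist yields 28x28 float images in [0, 1], not values in [0, 255]. Since this commit is about batch norm after fully-connected layers, here is a sketch of producing the 2D [batch, features] activations that case involves (names are hypothetical; same get_data() assumption as above):

# Sketch only: flatten MNIST images into vectors for an FC layer.
import numpy as np

ds = Mnist('train')
vecs = []
for img, label in ds.get_data():
    vecs.append(np.asarray(img).reshape(-1))   # 784-dim vector in [0, 1]
    if len(vecs) == 32:
        break
batch = np.stack(vecs)   # [32, 784]: the 2D shape BatchNorm now accepts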
@@ -11,7 +11,7 @@ __all__ = ['BatchNorm']
 # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
-# Only work for 4D tensor right now: #804
+# TF batch_norm only works for 4D tensor right now: #804
 @layer_register()
 def BatchNorm(x, is_training):
     """
@@ -22,12 +22,15 @@ def BatchNorm(x, is_training):
     Whole-population mean/variance is estimated by a running average of batch statistics, with decay rate 0.999.
     Epsilon for variance is set to 1e-5, as in torch/nn: https://github.com/torch/nn/blob/master/BatchNormalization.lua
-        x: BHWC tensor
+        x: BHWC tensor or a vector
         is_training: bool
     """
     EPS = 1e-5
     is_training = bool(is_training)
     shape = x.get_shape().as_list()
+    if len(shape) == 2:
+        x = tf.reshape(x, [-1, 1, 1, shape[1]])
+        shape = x.get_shape().as_list()
     assert len(shape) == 4
     n_out = shape[-1]  # channel
......
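The three added lines above are the whole fix: a 2D [batch, channels] input (the output of a fully-connected layer) is reshaped to [batch, 1, 1, channels], so the existing BHWC-only code path applies unchanged and per-channel statistics are still taken over axes [0, 1, 2]. Below is a self-contained sketch of that trick, following the StackOverflow recipe linked in the comment; it is not tensorpack's actual BatchNorm (variable handling, names, and the reshape back at the end are my assumptions) and assumes TF 1.x:

import tensorflow as tf

def batch_norm_sketch(x, is_training, decay=0.999, eps=1e-5):
    # Sketch of batch norm with the FC fix; not the real implementation.
    orig_shape = x.get_shape().as_list()
    if len(orig_shape) == 2:
        # The fix from this commit: view [B, C] activations as a
        # [B, 1, 1, C] "image" so the 4D-only path below applies.
        x = tf.reshape(x, [-1, 1, 1, orig_shape[1]])
    assert len(x.get_shape().as_list()) == 4
    n_out = x.get_shape().as_list()[-1]

    beta = tf.Variable(tf.zeros([n_out]), name='beta')    # learned shift
    gamma = tf.Variable(tf.ones([n_out]), name='gamma')   # learned scale
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

    # Population statistics tracked by a running average,
    # decay 0.999 as the docstring specifies.
    ema = tf.train.ExponentialMovingAverage(decay=decay)
    ema_apply_op = ema.apply([batch_mean, batch_var])

    if is_training:
        # Update the running averages, then normalize with batch stats.
        with tf.control_dependencies([ema_apply_op]):
            mean, var = tf.identity(batch_mean), tf.identity(batch_var)
    else:
        # Normalize with the (checkpoint-restored) population statistics.
        mean, var = ema.average(batch_mean), ema.average(batch_var)

    out = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
    if len(orig_shape) == 2:
        out = tf.reshape(out, [-1, orig_shape[1]])   # back to [B, C]
    return out

Whether the real layer reshapes its output back to 2D is not visible in the truncated hunk; the round-trip here is only to make the sketch drop-in usable after an FC layer.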