Commit 7fe010cb authored by Yuxin Wu

batch norm for fc

parent a19a57b2
@@ -51,6 +51,10 @@ def read_cifar10(filenames):
             yield [img, label[k]]
 
 class Cifar10(DataFlow):
+    """
+    Return [image, label],
+    image is 32x32x3 in the range [0,255]
+    """
     def __init__(self, train_or_test, dir=None):
         """
         Args:
@@ -96,6 +96,10 @@ class DataSet(object):
         return self._num_examples
 
 class Mnist(DataFlow):
+    """
+    Return [image, label],
+    image is 28x28 in the range [0,1]
+    """
     def __init__(self, train_or_test, dir=None):
         """
         Args:
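The docstrings added above pin down the datapoint format each dataflow yields. The short sketch below is not part of this commit and only checks those documented shapes and value ranges; the import path, the get_data() generator, and the 'train' argument value are assumptions about surrounding tensorpack code that this diff does not show.

# Sanity-check sketch for the documented datapoint formats. Not part of the
# commit; the import path and the get_data() iterator are assumptions about
# code outside this diff.
from tensorpack.dataflow.dataset import Cifar10, Mnist  # hypothetical import path

def check_datapoint_formats():
    # Cifar10: each datapoint is [image, label], image is 32x32x3 in [0,255].
    img, label = next(Cifar10('train').get_data())        # assumes 'train' is a valid value
    assert img.shape == (32, 32, 3) and 0 <= img.min() and img.max() <= 255

    # Mnist: each datapoint is [image, label], image is 28x28 in [0,1].
    img, label = next(Mnist('train').get_data())
    assert img.shape == (28, 28) and 0.0 <= img.min() and img.max() <= 1.0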
@@ -11,7 +11,7 @@ __all__ = ['BatchNorm']
 # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
-# Only work for 4D tensor right now: #804
+# TF batch_norm only works for 4D tensor right now: #804
 @layer_register()
 def BatchNorm(x, is_training):
     """
@@ -22,12 +22,15 @@ def BatchNorm(x, is_training):
     Whole-population mean/variance is calculated by a running-average mean/variance, with decay rate 0.999
     Epsilon for variance is set to 1e-5, as is torch/nn: https://github.com/torch/nn/blob/master/BatchNormalization.lua
-    x: BHWC tensor
+    x: BHWC tensor or a vector
     is_training: bool
     """
     EPS = 1e-5
     is_training = bool(is_training)
     shape = x.get_shape().as_list()
+    if len(shape) == 2:
+        x = tf.reshape(x, [-1, 1, 1, shape[1]])
+        shape = x.get_shape().as_list()
     assert len(shape) == 4
     n_out = shape[-1] # channel
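The new len(shape) == 2 branch is what makes batch norm usable after a fully-connected layer: an [N, C] activation is reshaped to [N, 1, 1, C] so the existing 4D code path applies unchanged. The standalone numpy sketch below is not part of the commit; it only illustrates why reducing the reshaped tensor over the (N, H, W) axes gives exactly the per-channel batch statistics of the original 2D tensor.

# Standalone numpy sketch (not part of this commit) of the idea behind the
# 2-D branch: viewing an [N, C] FC output as NHWC with H = W = 1 and reducing
# over (N, H, W), as the 4-D path does, equals reducing over the batch axis.
import numpy as np

N, C = 64, 10
x = np.random.randn(N, C).astype('float32')    # e.g. output of an FC layer

# Statistics of the 2-D tensor, reduced over the batch axis only.
mean_2d, var_2d = x.mean(axis=0), x.var(axis=0)

# Same data reshaped to [N, 1, 1, C], reduced over (N, H, W) like the 4-D path.
x4 = x.reshape(N, 1, 1, C)
mean_4d, var_4d = x4.mean(axis=(0, 1, 2)), x4.var(axis=(0, 1, 2))

assert np.allclose(mean_2d, mean_4d) and np.allclose(var_2d, var_4d)

# Normalizing with the same epsilon used in the code above.
EPS = 1e-5
y = (x4 - mean_4d) / np.sqrt(var_4d + EPS)     # shape [N, 1, 1, C]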