Commit 575138ff authored by Yuxin Wu

svhn config & rng in dataflow

parent b2f8fec3
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: example_svhn_digit.py
+# File: svhn_fast.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
@@ -19,7 +19,7 @@ from tensorpack.dataflow import imgaug
"""
SVHN convnet.
About 2.9% validation error after 70 epoch.
About 3.0% validation error after 120 epoch. 2.7% after 250 epoch.
"""
class Model(ModelDesc):
@@ -103,8 +103,8 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-3,
         global_step=get_global_step_var(),
-        decay_steps=train.size() * 30,
-        decay_rate=0.5, staircase=True, name='learning_rate')
+        decay_steps=train.size() * 60,
+        decay_rate=0.2, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
     return TrainConfig(
...
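The config change above makes the schedule decay harder but less often: the learning rate now drops by 5x every 60 epochs instead of 2x every 30. A minimal sketch of what the staircase schedule evaluates to, with `steps_per_epoch` standing in for `train.size()` (the numbers below are illustrative, not from the commit):

```python
def lr_at(step, steps_per_epoch, base_lr=1e-3, rate=0.2, every_epochs=60):
    # staircase=True floors the exponent, so the rate drops in discrete jumps
    return base_lr * rate ** (step // (steps_per_epoch * every_epochs))

steps_per_epoch = 400                    # stand-in for train.size()
for epoch in (1, 59, 60, 120, 250):
    print(epoch, lr_at(epoch * steps_per_epoch, steps_per_epoch))
# 1e-3 through epoch 59, 2e-4 from epoch 60, 4e-5 from 120,
# 8e-6 from 180, 1.6e-6 from 240 onwards
```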
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import numpy as np
-import random
 import copy
 from six.moves import range
 from .base import DataFlow, ProxyDataFlow
@@ -166,6 +165,7 @@ class RandomChooseData(DataFlow):
         else:
             prob = 1.0 / len(df_lists)
             self.df_lists = [(k, prob) for k in df_lists]
+        self.rng = get_rng(self)

     def reset_state(self):
         for d in self.df_lists:
@@ -173,13 +173,14 @@ class RandomChooseData(DataFlow):
                 d[0].reset_state()
             else:
                 d.reset_state()
+        self.rng = get_rng(self)

     def get_data(self):
         itrs = [v[0].get_data() for v in self.df_lists]
         probs = np.array([v[1] for v in self.df_lists])
         try:
             while True:
-                itr = np.random.choice(itrs, p=probs)
+                itr = self.rng.choice(itrs, p=probs)
                 yield next(itr)
         except StopIteration:
             return
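Both hunks above swap the module-level `np.random` for the instance's own `self.rng`; the weighted draw itself is unchanged. A standalone illustration of the `rng.choice` pattern, using toy iterators instead of dataflows:

```python
import numpy as np

rng = np.random.RandomState(0)
itrs = [iter('AAAA'), iter('BBBB')]   # toy stand-ins for dataflow iterators
for _ in range(4):
    # numpy treats the list as a 1-D object array and samples one iterator
    print(next(rng.choice(itrs, p=[0.8, 0.2])))
```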
@@ -196,10 +197,12 @@ class RandomMixData(DataFlow):
"""
self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists]
self.rng = get_rng(self)
def reset_state(self):
for d in self.df_lists:
d.reset_state()
self.rng = get_rng(self)
def size(self):
return sum(self.sizes)
@@ -207,7 +210,7 @@ class RandomMixData(DataFlow):
     def get_data(self):
         sums = np.cumsum(self.sizes)
         idxs = np.arange(self.size())
-        np.random.shuffle(idxs)
+        self.rng.shuffle(idxs)
         idxs = np.array(map(
             lambda x: np.searchsorted(sums, x, 'right'), idxs))
         itrs = [k.get_data() for k in self.df_lists]
...
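The pattern this file now follows everywhere: seed a per-instance `RandomState` in `__init__`, then re-seed it in `reset_state()`. The re-seed matters for fork safety: dataflows copied into prefetch worker processes inherit the parent's RNG state, and only a fresh seed per process keeps the workers from producing identical "random" streams. A plausible sketch of the `get_rng` helper, assuming it mixes object identity, PID and time (the real tensorpack utility may differ in detail):

```python
import os
import time

import numpy as np

def get_rng(obj=None):
    # Seed from the object, the process id and the clock, so each
    # (object, process) pair that calls reset_state() gets its own stream.
    seed = (id(obj) + os.getpid() + int(time.time() * 1000)) % (2 ** 32 - 1)
    return np.random.RandomState(seed)
```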
@@ -12,7 +12,7 @@ import copy
 import tarfile
 import logging
-from ...utils import logger
+from ...utils import logger, get_rng
 from ..base import DataFlow

 __all__ = ['Cifar10']
@@ -93,14 +93,18 @@ class Cifar10(DataFlow):
         self.dir = dir
         self.data = read_cifar10(self.fs)
         self.shuffle = shuffle
+        self.rng = get_rng(self)

+    def reset_state(self):
+        self.rng = get_rng(self)

     def size(self):
         return 50000 if self.train_or_test == 'train' else 10000

     def get_data(self):
-        idxs = list(range(len(self.data)))
+        idxs = np.arange(len(self.data))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             yield self.data[k]
...
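With this hunk, Cifar10 owns its shuffling RNG instead of leaning on Python's global `random`, and gains the same `reset_state()` contract as the flows above. Hypothetical usage (the constructor arguments follow the diff; the datapoint layout is an assumption):

```python
df = Cifar10('train', shuffle=True)
df.reset_state()                  # re-seeds df.rng via get_rng(df)
for dp in df.get_data():          # assumed: dp is an (image, label) pair
    pass                          # one pass yields all 50000 train datapoints
```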
@@ -13,8 +13,10 @@ __all__ = ['BatchNorm']
 # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
 # TF batch_norm only works for 4D tensor right now: #804
+# decay: 0.999 not good for resnet, torch use 0.9 by default
+# eps: torch: 1e-5. Lasagne: 1e-4

 @layer_register()
-def BatchNorm(x, is_training=True, gamma_init=1.0):
+def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     """
     Batch normalization layer as described in:
     Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
@@ -24,10 +26,10 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     Epsilon for variance is set to 1e-5, as is torch/nn: https://github.com/torch/nn/blob/master/BatchNormalization.lua

     x: BHWC tensor or a vector
-    is_training: bool
+    use_local_stat: bool. whether to use mean/var of this batch or the running
+        average. Usually set to True in training and False in testing
     """
-    EPS = 1e-5
-    is_training = bool(is_training)
     shape = x.get_shape().as_list()
     if len(shape) == 2:
         x = tf.reshape(x, [-1, 1, 1, shape[1]])
@@ -36,8 +38,9 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     n_out = shape[-1]  # channel
     beta = tf.get_variable('beta', [n_out])
-    gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.constant_initializer(gamma_init))
+    gamma = tf.get_variable(
+        'gamma', [n_out],
+        initializer=tf.constant_initializer(1.0))

     # XXX hack to clear shape. see tensorflow#1162
     if shape[0] is not None:
@@ -48,16 +51,16 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

-    ema = tf.train.ExponentialMovingAverage(decay=0.999)
+    ema = tf.train.ExponentialMovingAverage(decay=decay)
     ema_apply_op = ema.apply([batch_mean, batch_var])
     ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

-    if is_training:
+    if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
             return tf.nn.batch_norm_with_global_normalization(
-                x, batch_mean, batch_var, beta, gamma, EPS, True)
+                x, batch_mean, batch_var, beta, gamma, epsilon, True)
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
         mean, var = ema_mean, ema_var * batch / (batch - 1)  # unbiased variance estimator
         return tf.nn.batch_norm_with_global_normalization(
-            x, mean, var, beta, gamma, EPS, True)
+            x, mean, var, beta, gamma, epsilon, True)
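The new `decay` comment is easy to quantify: an exponential moving average with rate `decay` averages over roughly the last 1 / (1 - decay) updates, which is why 0.999 adapts too slowly in the resnet case while torch's 0.9 tracks recent batch statistics. A quick check:

```python
# Effective averaging window of an EMA is roughly 1 / (1 - decay) updates.
for decay in (0.999, 0.9):
    print(decay, '-> ~ last', round(1 / (1 - decay)), 'batches')
# 0.999 -> ~ last 1000 batches (slow to adapt; "not good for resnet")
# 0.9   -> ~ last 10 batches   (torch default, now this layer's default too)
```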