Commit 575138ff authored by Yuxin Wu

svhn config & rng in dataflow

parent b2f8fec3
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: example_svhn_digit.py
+# File: svhn_fast.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
@@ -19,7 +19,7 @@ from tensorpack.dataflow import imgaug
 """
 SVHN convnet.
-About 2.9% validation error after 70 epoch.
+About 3.0% validation error after 120 epoch. 2.7% after 250 epoch.
 """

 class Model(ModelDesc):
@@ -103,8 +103,8 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-3,
         global_step=get_global_step_var(),
-        decay_steps=train.size() * 30,
-        decay_rate=0.5, staircase=True, name='learning_rate')
+        decay_steps=train.size() * 60,
+        decay_rate=0.2, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
     return TrainConfig(
...
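For orientation, a minimal plain-Python sketch of the staircase schedule the new get_config() produces; `svhn_lr` and `steps_per_epoch` are illustrative names (steps_per_epoch stands in for train.size()), not code from the commit:

    # Sketch of the new exponential_decay settings: learning_rate=1e-3,
    # decay_steps=train.size() * 60, decay_rate=0.2, staircase=True.
    def svhn_lr(step, steps_per_epoch):
        return 1e-3 * 0.2 ** (step // (steps_per_epoch * 60))

    # With staircase decay the rate drops 5x every 60 epochs:
    # epochs 0-59 -> 1e-3, epochs 60-119 -> 2e-4, epochs 120-179 -> 4e-5, ...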
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import numpy as np
-import random
 import copy
 from six.moves import range
 from .base import DataFlow, ProxyDataFlow
@@ -166,6 +165,7 @@ class RandomChooseData(DataFlow):
         else:
             prob = 1.0 / len(df_lists)
             self.df_lists = [(k, prob) for k in df_lists]
+        self.rng = get_rng(self)

     def reset_state(self):
         for d in self.df_lists:
@@ -173,13 +173,14 @@ class RandomChooseData(DataFlow):
                 d[0].reset_state()
             else:
                 d.reset_state()
+        self.rng = get_rng(self)

     def get_data(self):
         itrs = [v[0].get_data() for v in self.df_lists]
         probs = np.array([v[1] for v in self.df_lists])
         try:
             while True:
-                itr = np.random.choice(itrs, p=probs)
+                itr = self.rng.choice(itrs, p=probs)
                 yield next(itr)
         except StopIteration:
             return
@@ -196,10 +197,12 @@ class RandomMixData(DataFlow):
         """
         self.df_lists = df_lists
         self.sizes = [k.size() for k in self.df_lists]
+        self.rng = get_rng(self)

     def reset_state(self):
         for d in self.df_lists:
             d.reset_state()
+        self.rng = get_rng(self)

     def size(self):
         return sum(self.sizes)
@@ -207,7 +210,7 @@ class RandomMixData(DataFlow):
     def get_data(self):
         sums = np.cumsum(self.sizes)
         idxs = np.arange(self.size())
-        np.random.shuffle(idxs)
+        self.rng.shuffle(idxs)
         idxs = np.array(map(
             lambda x: np.searchsorted(sums, x, 'right'), idxs))
         itrs = [k.get_data() for k in self.df_lists]
...
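The change above stops drawing from the global np.random and gives each DataFlow a private self.rng, re-seeded in reset_state() so that forked worker processes do not share one random stream. A plausible sketch of what a get_rng helper can look like; the seeding details here are an assumption, not the commit's code:

    import os
    import time
    import numpy as np

    def get_rng(obj=None):
        # Assumption: mix the object id, pid, and clock into the seed so each
        # process (e.g. after a fork) gets an independent RandomState.
        seed = (id(obj) + os.getpid() + int(time.time() * 1000)) % (2 ** 31)
        return np.random.RandomState(seed)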
@@ -12,7 +12,7 @@ import copy
 import tarfile
 import logging
-from ...utils import logger
+from ...utils import logger, get_rng
 from ..base import DataFlow

 __all__ = ['Cifar10']
@@ -93,14 +93,18 @@ class Cifar10(DataFlow):
         self.dir = dir
         self.data = read_cifar10(self.fs)
         self.shuffle = shuffle
+        self.rng = get_rng(self)
+
+    def reset_state(self):
+        self.rng = get_rng(self)

     def size(self):
         return 50000 if self.train_or_test == 'train' else 10000

     def get_data(self):
-        idxs = list(range(len(self.data)))
+        idxs = np.arange(len(self.data))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             yield self.data[k]
...
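A hypothetical usage sketch of the reworked Cifar10 dataflow; the constructor arguments are assumed from the surrounding code, and datapoints are assumed to be [image, label] pairs:

    ds = Cifar10('train', shuffle=True)
    ds.reset_state()          # re-seed ds.rng, e.g. inside a worker process
    for img, label in ds.get_data():
        pass                  # each epoch shuffles with ds.rng, not random.shuffle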
@@ -13,8 +13,10 @@ __all__ = ['BatchNorm']
 # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
 # TF batch_norm only works for 4D tensor right now: #804
+# decay: 0.999 not good for resnet, torch use 0.9 by default
+# eps: torch: 1e-5. Lasagne: 1e-4

 @layer_register()
-def BatchNorm(x, is_training=True, gamma_init=1.0):
+def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     """
     Batch normalization layer as described in:
     Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
@@ -24,10 +26,10 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     Epsilon for variance is set to 1e-5, as is torch/nn: https://github.com/torch/nn/blob/master/BatchNormalization.lua
     x: BHWC tensor or a vector
-    is_training: bool
+    use_local_stat: bool. whether to use mean/var of this batch or the running
+        average. Usually set to True in training and False in testing
     """
-    EPS = 1e-5
-    is_training = bool(is_training)
     shape = x.get_shape().as_list()
     if len(shape) == 2:
         x = tf.reshape(x, [-1, 1, 1, shape[1]])
@@ -36,8 +38,9 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     n_out = shape[-1]  # channel
     beta = tf.get_variable('beta', [n_out])
-    gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.constant_initializer(gamma_init))
+    gamma = tf.get_variable(
+        'gamma', [n_out],
+        initializer=tf.constant_initializer(1.0))

     # XXX hack to clear shape. see tensorflow#1162
     if shape[0] is not None:
@@ -48,16 +51,16 @@ def BatchNorm(x, is_training=True, gamma_init=1.0):
     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

-    ema = tf.train.ExponentialMovingAverage(decay=0.999)
+    ema = tf.train.ExponentialMovingAverage(decay=decay)
     ema_apply_op = ema.apply([batch_mean, batch_var])
     ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

-    if is_training:
+    if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
             return tf.nn.batch_norm_with_global_normalization(
-                x, batch_mean, batch_var, beta, gamma, EPS, True)
+                x, batch_mean, batch_var, beta, gamma, epsilon, True)
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
         mean, var = ema_mean, ema_var * batch / (batch - 1)  # unbiased variance estimator
         return tf.nn.batch_norm_with_global_normalization(
-            x, mean, var, beta, gamma, EPS, True)
+            x, mean, var, beta, gamma, epsilon, True)
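To make the new decay parameter concrete: tf.train.ExponentialMovingAverage maintains shadow = decay * shadow + (1 - decay) * value. The plain-Python restatement below (illustrative only, not the commit's code) shows why 0.999 adapts too slowly for shorter runs, while torch's default of 0.9 tracks the batch statistics much faster:

    def ema_update(shadow, value, decay=0.9):
        # One EMA step as applied to batch_mean/batch_var. With decay=0.999,
        # roughly 1/(1-decay) = 1000 updates are needed to mostly forget the
        # initial value, versus ~10 updates with decay=0.9.
        return decay * shadow + (1 - decay) * value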