Commit 7207816d authored by Yuxin Wu's avatar Yuxin Wu

fix bug in svhn

parent 48a19f6d
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: svhn_digit_convnet.py # File: svhn-digit-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
...@@ -71,8 +71,8 @@ def get_config(): ...@@ -71,8 +71,8 @@ def get_config():
# prepare dataset # prepare dataset
d1 = dataset.SVHNDigit('train') d1 = dataset.SVHNDigit('train')
d2 = dataset.SVHNDigit('extra') d2 = dataset.SVHNDigit('extra')
train = RandomMixData([d1, d2]) data_train = RandomMixData([d1, d2])
test = dataset.SVHNDigit('test') data_test = dataset.SVHNDigit('test')
augmentors = [ augmentors = [
imgaug.Resize((40, 40)), imgaug.Resize((40, 40)),
...@@ -82,37 +82,34 @@ def get_config(): ...@@ -82,37 +82,34 @@ def get_config():
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)], [(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
(40,40), 0.2, 3), (40,40), 0.2, 3),
] ]
train = AugmentImageComponent(train, augmentors) data_train = AugmentImageComponent(data_train, augmentors)
train = BatchData(train, 128) data_train = BatchData(data_train, 128)
nr_proc = 5 nr_proc = 5
train = PrefetchData(train, 5, nr_proc) data_train = PrefetchData(data_train, 5, nr_proc)
step_per_epoch = train.size() step_per_epoch = data_train.size()
augmentors = [ augmentors = [
imgaug.Resize((40, 40)), imgaug.Resize((40, 40)),
] ]
test = AugmentImageComponent(test, augmentors) data_test = AugmentImageComponent(data_test, augmentors)
test = BatchData(test, 128, remainder=True) data_test = BatchData(data_test, 128, remainder=True)
sess_config = get_default_sess_config(0.8)
lr = tf.train.exponential_decay( lr = tf.train.exponential_decay(
learning_rate=1e-3, learning_rate=1e-3,
global_step=get_global_step_var(), global_step=get_global_step_var(),
decay_steps=train.size() * 60, decay_steps=data_train.size() * 60,
decay_rate=0.2, staircase=True, name='learning_rate') decay_rate=0.2, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return TrainConfig( return TrainConfig(
dataset=train, dataset=data_train,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
InferenceRunner(dataset_test, InferenceRunner(data_test,
[ScalarStats('cost'), ClassificationError()]) [ScalarStats('cost'), ClassificationError()])
]), ]),
session_config=sess_config,
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=350, max_epoch=350,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment