Commit 565404ec authored by Yuxin Wu's avatar Yuxin Wu

update a cifar config

parent 701628b4
......@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
CIFAR10 89% test accuracy after 60k step (about 150 epochs)
CIFAR10 90% validation accuracy after 100k step, 91% after 160k step.
"""
BATCH_SIZE = 128
......@@ -91,6 +91,7 @@ class Model(ModelDesc):
return tf.add_n([cost, wd_cost], name='cost')
def get_config():
#anchors = np.mgrid[0:4,0:4][:,1:,1:].transpose(1,2,0).reshape((-1,2)) / 4.0
# prepare dataset
dataset_train = dataset.Cifar10('train')
augmentors = [
......@@ -98,14 +99,15 @@ def get_config():
imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(63),
imgaug.Contrast((0.2,1.8)),
#imgaug.GaussianDeform([(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
#(30,30), 0.2, 3),
imgaug.GaussianDeform(
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
(30,30), 0.2, 3),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
#dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size()
dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size() / 2
augmentors = [
imgaug.CenterCrop((30, 30)),
......@@ -121,8 +123,8 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 40,
decay_rate=0.4, staircase=True, name='learning_rate')
decay_steps=dataset_train.size() * 30,
decay_rate=0.7, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......@@ -136,7 +138,7 @@ def get_config():
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=300,
max_epoch=500,
)
if __name__ == '__main__':
......
......@@ -45,7 +45,7 @@ def add_param_summary(regex):
if p.get_shape().ndims == 0:
tf.scalar_summary(name, p)
else:
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(p))
#tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(p))
tf.histogram_summary(name, p)
def summary_moving_average(cost_var):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment