Commit 3754faac authored by Yuxin Wu's avatar Yuxin Wu

cifar with plateau detection

parent 5aab2d2d
......@@ -40,7 +40,7 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6
# NOTE: will consume at least 1e6 * 84 * 84 * 4 bytes = 26G memory.
# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
# Suggest using tcmalloc to manage memory space better.
INIT_MEMORY_SIZE = 5e4
STEP_PER_EPOCH = 10000
......
......@@ -15,8 +15,7 @@ from tensorpack.tfutils.summary import *
A small convnet model for Cifar10 or Cifar100 dataset.
Cifar10:
90% validation accuracy after 40k step.
91% accuracy after 80k step.
91% accuracy after 50k step.
19.3 step/s on Tesla M40
Not a good model for Cifar100, just for demonstration.
......@@ -66,7 +65,7 @@ class Model(ModelDesc):
add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(0.004,
wd_cost = tf.mul(0.0004,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
add_moving_summary(cost, wd_cost)
......@@ -112,26 +111,27 @@ def get_config(cifar_classnum):
sess_config = get_default_sess_config(0.5)
nr_gpu = get_nr_gpu()
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
decay_rate=0.5, staircase=True, name='learning_rate')
lr = tf.Variable(1e-2, name='learning_rate',
dtype=tf.float32, trainable=False)
tf.scalar_summary('learning_rate', lr)
def lr_func(lr):
    """Compute the next learning rate from the current one.

    Passed to ``StatMonitorParamSetter`` for the ``'learning_rate'``
    parameter, monitored against ``'val_error'``.
    NOTE(review): presumably invoked when the monitored statistic stops
    improving -- confirm against StatMonitorParamSetter.

    Args:
        lr: the current learning-rate value.

    Returns:
        The current value multiplied by 0.31.

    Raises:
        StopTraining: if the current value is already below 3e-5.
    """
    # Once the rate has dropped under the floor, end training instead of
    # shrinking it any further.
    if lr < 3e-5:
        raise StopTraining()
    return lr * 0.31
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
InferenceRunner(dataset_test, ClassificationError())
StatPrinter(), ModelSaver(),
InferenceRunner(dataset_test, ClassificationError()),
StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
threshold=0.001, last_k=10),
]),
session_config=sess_config,
model=Model(cifar_classnum),
step_per_epoch=step_per_epoch,
max_epoch=250,
max_epoch=150,
)
if __name__ == '__main__':
......
......@@ -8,7 +8,7 @@ import os
import time
from abc import abstractmethod, ABCMeta
__all__ = ['Callback', 'PeriodicCallback']
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
class Callback(object):
""" Base class for all callbacks """
......
......@@ -18,7 +18,10 @@ from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.summary import create_summary
__all__ = ['Trainer']
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
    """Raised to request that training stop.

    The training loop catches this exception and logs
    "Training was stopped." instead of re-raising it.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not catch it.
    """
class Trainer(object):
"""
......@@ -138,6 +141,8 @@ class Trainer(object):
#callbacks.trigger_step() # not useful?
self.global_step += 1
self.trigger_epoch()
except StopTraining:
logger.info("Training was stopped.")
except (KeyboardInterrupt, Exception):
raise
finally:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment