Commit 3754faac authored by Yuxin Wu

cifar with plateau detection

parent 5aab2d2d
@@ -40,7 +40,7 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
 END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
-# NOTE: will consume at least 1e6 * 84 * 84 * 4 bytes = 26G memory.
+# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 # Suggest using tcmalloc to manage memory space better.
 INIT_MEMORY_SIZE = 5e4
 STEP_PER_EPOCH = 10000
...
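The corrected arithmetic checks out: dropping the `* 4` corresponds to storing one byte per pixel (presumably uint8 frames rather than 4-byte values). A quick sanity check in plain Python:

```python
# Replay memory footprint: frames * height * width * bytes_per_pixel.
frames, h, w = int(1e6), 84, 84
print(frames * h * w * 4 / 2.0**30)  # 4 bytes/pixel: ~26.3 GiB (old comment)
print(frames * h * w / 2.0**30)      # 1 byte/pixel:  ~6.6 GiB (new comment)
```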
@@ -15,8 +15,7 @@ from tensorpack.tfutils.summary import *
 A small convnet model for Cifar10 or Cifar100 dataset.
 Cifar10:
-    90% validation accuracy after 40k step.
-    91% accuracy after 80k step.
+    91% accuracy after 50k step.
     19.3 step/s on Tesla M40
 Not a good model for Cifar100, just for demonstration.
@@ -66,7 +65,7 @@ class Model(ModelDesc):
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
         # weight decay on all W of fc layers
-        wd_cost = tf.mul(0.004,
+        wd_cost = tf.mul(0.0004,
                          regularize_cost('fc.*/W', tf.nn.l2_loss),
                          name='regularize_loss')
         add_moving_summary(cost, wd_cost)
@@ -112,26 +111,27 @@ def get_config(cifar_classnum):
     sess_config = get_default_sess_config(0.5)
-    nr_gpu = get_nr_gpu()
-    lr = tf.train.exponential_decay(
-        learning_rate=1e-2,
-        global_step=get_global_step_var(),
-        decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
-        decay_rate=0.5, staircase=True, name='learning_rate')
+    lr = tf.Variable(1e-2, name='learning_rate',
+                     dtype=tf.float32, trainable=False)
     tf.scalar_summary('learning_rate', lr)

+    def lr_func(lr):
+        if lr < 3e-5:
+            raise StopTraining()
+        return lr * 0.31
+
     return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
-            StatPrinter(),
-            ModelSaver(),
-            InferenceRunner(dataset_test, ClassificationError())
+            StatPrinter(), ModelSaver(),
+            InferenceRunner(dataset_test, ClassificationError()),
+            StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
+                                   threshold=0.001, last_k=10),
         ]),
         session_config=sess_config,
         model=Model(cifar_classnum),
         step_per_epoch=step_per_epoch,
-        max_epoch=250,
+        max_epoch=150,
     )

 if __name__ == '__main__':
...
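The new StatMonitorParamSetter callback watches a training statistic and rewrites a hyper-parameter when that statistic plateaus. A minimal standalone sketch of the plateau logic (the class name `PlateauMonitor` and the exact trigger rule are illustrative, not tensorpack's actual implementation):

```python
class PlateauMonitor(object):
    """Rough sketch of the plateau rule behind StatMonitorParamSetter.

    If the monitored statistic ('val_error' above) moved by less than
    `threshold` over the last `last_k` readings, apply `param_func` to
    the current parameter value (here, the learning rate).
    """
    def __init__(self, param_func, threshold=0.001, last_k=10):
        self.param_func = param_func
        self.threshold = threshold
        self.last_k = last_k
        self.history = []

    def update(self, stat_value, param_value):
        self.history.append(stat_value)
        if len(self.history) < self.last_k:
            return param_value                    # not enough history yet
        window = self.history[-self.last_k:]
        if max(window) - min(window) < self.threshold:
            self.history = []                     # start fresh after a trigger
            return self.param_func(param_value)   # e.g. lr -> lr * 0.31
        return param_value
```

With `lr_func` from the diff, every trigger multiplies the learning rate by 0.31 (two triggers shrink it by roughly one order of magnitude, since 0.31 * 0.31 ≈ 0.096), and training ends via `StopTraining` once the rate falls below 3e-5.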
@@ -8,7 +8,7 @@ import os
 import time
 from abc import abstractmethod, ABCMeta

-__all__ = ['Callback', 'PeriodicCallback']
+__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']

 class Callback(object):
     """ Base class for all callbacks """
...
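`ProxyCallback` itself is not shown in this diff; the name suggests the standard proxy pattern of a callback that forwards every hook to a wrapped callback. A generic sketch under that assumption (the hook names are also assumptions, not confirmed by the diff):

```python
class Callback(object):
    """Minimal stand-in for tensorpack's Callback base (hook names assumed)."""
    def before_train(self): pass
    def trigger_epoch(self): pass
    def after_train(self): pass

class ProxyCallback(Callback):
    """Forward every hook to a wrapped callback -- the usual proxy pattern."""
    def __init__(self, cb):
        self.cb = cb
    def before_train(self):
        self.cb.before_train()
    def trigger_epoch(self):
        self.cb.trigger_epoch()
    def after_train(self):
        self.cb.after_train()
```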
@@ -18,7 +18,10 @@ from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
 from ..tfutils.summary import create_summary

-__all__ = ['Trainer']
+__all__ = ['Trainer', 'StopTraining']
+
+class StopTraining(BaseException):
+    pass

 class Trainer(object):
     """
@@ -138,6 +141,8 @@ class Trainer(object):
                 #callbacks.trigger_step() # not useful?
                 self.global_step += 1
             self.trigger_epoch()
+        except StopTraining:
+            logger.info("Training was stopped.")
         except (KeyboardInterrupt, Exception):
             raise
         finally:
...
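Together with `lr_func` above, this gives callbacks a clean way to end a run: raise `StopTraining` and the main loop logs a message instead of dumping a stack trace. Subclassing `BaseException` rather than `Exception` presumably keeps broad `except Exception` handlers elsewhere from swallowing it. A condensed standalone sketch of the control flow (the helpers `run_epoch`, `trigger_epoch`, `cleanup` are illustrative):

```python
class StopTraining(BaseException):
    """Raised anywhere inside the training loop to end training cleanly."""

def main_loop(max_epoch, run_epoch, trigger_epoch, cleanup):
    try:
        for epoch in range(1, max_epoch + 1):
            run_epoch(epoch)
            trigger_epoch()        # callbacks may raise StopTraining here
    except StopTraining:
        print("Training was stopped.")   # intentional, quiet stop
    except (KeyboardInterrupt, Exception):
        raise                            # genuine errors still propagate
    finally:
        cleanup()                        # always runs, stopped or not
```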