Commit 76fe1b6b authored by Yuxin Wu

update cifar number & fix multigpu restore bug

parent da3da39d
@@ -24,7 +24,7 @@ This implementation uses the variants proposed in:
 Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results for
-n=5, about 7.6% val error
+n=5, about 7.2% val error after 93k step with 2 TitanX (6.8it/s)
 n=18, about 6.05% val error after 62k step with 2 TitanX (about 10hr)
 n=30: a 182-layer network, about 5.5% val error after 51k step with 2 GPUs
 This model uses the whole training set instead of a 95:5 train-val split.
...
@@ -20,8 +20,9 @@ from tensorpack.dataflow import imgaug
 """
-Reach 1.9% validation error after 90 epochs, with 2 GPUs.
-You might need to adjust learning rate schedule when running with 1 GPU.
+ResNet-110 for SVHN Digit Classification.
+Reach 1.9% validation error after 90 epochs, with 2 TitanX xxhr, 2it/s.
+You might need to adjust the learning rate schedule when running with 1 GPU.
 """
 BATCH_SIZE = 128
@@ -98,8 +99,7 @@ class Model(ModelDesc):
         logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
         prob = tf.nn.softmax(logits, name='output')
-        y = one_hot(label, 10)
-        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
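
Note: the new `sparse_softmax_cross_entropy_with_logits` call consumes the integer labels directly, so the manual `one_hot` step becomes unnecessary. A minimal NumPy sketch (illustrative only, not part of the commit) of why the dense and sparse formulations agree:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0]])
label = np.array([0])       # integer class id (the "sparse" form)
y = np.eye(3)[label]        # one-hot target (the "dense" form)

dense = -(y * np.log(softmax(logits))).sum(axis=-1)
sparse = -np.log(softmax(logits))[np.arange(len(label)), label]
assert np.allclose(dense, sparse)
```

Besides saving a line, the sparse op avoids materializing the one-hot tensor.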
@@ -167,8 +167,8 @@ def get_config():
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
-            ValidationError(dataset_test, prefix='test'),
+            ModelSaver(),
+            ClassificationError(dataset_test, prefix='validation'),
             ScheduledHyperParamSetter('learning_rate',
                 [(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
         ]),
...
@@ -114,7 +114,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
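
Note: the added parentheses are the actual multi-GPU fix here. Python's conditional expression binds more loosely than `*`, so the old line evaluated as `(step_per_epoch * 30) if nr_gpu == 1 else 20`, i.e. multi-GPU runs decayed the learning rate every 20 steps instead of every 20 epochs. A quick standalone check (`390` is just an example steps-per-epoch value):

```python
# Conditional expressions bind more loosely than '*'.
step_per_epoch, nr_gpu = 390, 2
old = step_per_epoch * 30 if nr_gpu == 1 else 20      # -> 20 (steps!)
new = step_per_epoch * (30 if nr_gpu == 1 else 20)    # -> 7800
assert old == 20 and new == 7800
```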
@@ -129,7 +129,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=3,
+        max_epoch=20,
     )

 if __name__ == '__main__':
...
@@ -26,11 +26,12 @@ class ModelSaver(Callback):
     def _before_train(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
-            var_list=self._get_vars(),
+            var_list=ModelSaver._get_vars(),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)

-    def _get_vars(self):
+    @staticmethod
+    def _get_vars():
         vars = tf.all_variables()
         var_dict = {}
         for v in vars:
...
@@ -70,6 +70,7 @@ class TestCallbackContext(object):
         with create_test_session(trainer) as sess:
             self.sess = sess
             self.graph = sess.graph
+            # no towers in the test graph, so keep variable names as-is
             self.saver = tf.train.Saver()
         with self.graph.as_default(), self.sess.as_default():
             yield
...
@@ -5,6 +5,8 @@
 import os
 from abc import abstractmethod, ABCMeta
 import numpy as np
+from collections import defaultdict
+import re
 import tensorflow as tf
 import six
@@ -38,7 +40,7 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by `tf.Saver`.
+    Restore an old model saved by `ModelSaver`.
     """
     def __init__(self, model_path):
         """
@@ -52,14 +54,60 @@ class SaverRestore(SessionInit):
         self.set_path(model_path)

     def _init(self, sess):
-        saver = tf.train.Saver()
-        saver.restore(sess, self.path)
         logger.info(
-            "Restore checkpoint from {}".format(self.path))
+            "Restoring checkpoint from {}.".format(self.path))
+        sess.run(tf.initialize_all_variables())
+        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
+        for dic in SaverRestore._produce_restore_dict(vars_map):
+            saver = tf.train.Saver(var_list=dic)
+            saver.restore(sess, self.path)

     def set_path(self, model_path):
         self.path = model_path
+
+    @staticmethod
+    def _produce_restore_dict(vars_multimap):
+        """
+        Produce {var_name: var} dicts usable by `tf.train.Saver`, from a {var_name: [vars]} multimap.
+        """
+        while len(vars_multimap):
+            ret = {}
+            for k in list(vars_multimap.keys()):  # copy keys: the dict shrinks while we iterate
+                v = vars_multimap[k]
+                ret[k] = v[-1]
+                del v[-1]
+                if not len(v):
+                    del vars_multimap[k]
+            yield ret
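
Note: `tf.train.Saver`'s `var_list` dict maps each checkpoint name to a single variable, but with replicated towers one checkpoint tensor must be restored into several copies. This generator therefore deals the multimap out round-robin, yielding one dict per restore pass. A hypothetical plain-Python trace (the names are made up for illustration):

```python
# Round-robin consumption of a {checkpoint_name: [variables]} multimap.
multimap = {'conv0/W': ['tower0/conv0/W', 'tower1/conv0/W'],
            'fc/b': ['fc/b']}
passes = []
while multimap:
    ret = {}
    for k in list(multimap.keys()):
        ret[k] = multimap[k].pop()
        if not multimap[k]:
            del multimap[k]
    passes.append(ret)
print(passes)
# [{'conv0/W': 'tower1/conv0/W', 'fc/b': 'fc/b'},
#  {'conv0/W': 'tower0/conv0/W'}]
```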
+
+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        reader = tf.train.NewCheckpointReader(model_path)
+        return set(reader.GetVariableToShapeMap().keys())
+
+    @staticmethod
+    def _get_vars_to_restore_multimap(vars_available):
+        """
+        Get a dict of {var_name: [var, var]} to restore.
+        :param vars_available: variables available in the checkpoint, used for existence checks
+        """
+        # TODO warn if some variable in checkpoint is not used
+        vars_to_restore = tf.all_variables()
+        var_dict = defaultdict(list)
+        for v in vars_to_restore:
+            name = v.op.name
+            if 'tower' in name:
+                name = re.sub('tower[0-9]+/', '', name)
+            if name in vars_available:
+                var_dict[name].append(v)
+            else:
+                logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        return var_dict
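
Note: the `re.sub` is what maps every tower replica back to the single copy stored in the checkpoint. A small standalone check, assuming tower prefixes of the form `tower<N>/` as in the code above:

```python
import re

# Tower-prefixed replicas collapse to one checkpoint name;
# unprefixed names pass through unchanged.
for name in ['tower0/conv0/W', 'tower12/bn/gamma', 'fc/b']:
    print(re.sub('tower[0-9]+/', '', name))
# conv0/W
# bn/gamma
# fc/b
```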
+
 class ParamRestore(SessionInit):
     """
     Restore trainable variables from a dictionary.
...
@@ -83,7 +83,7 @@ class Trainer(object):
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))
-        for epoch in range(self.config.starting_epoch, self.config.max_epoch):
+        for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
             with timed_operation(
                     'Epoch {}, global_step={}'.format(
                         epoch, self.global_step + self.config.step_per_epoch)):
...
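
Note: the `+1` fixes an off-by-one: `range` excludes its stop value, so with `max_epoch=20` the loop previously ended after epoch 19. For illustration:

```python
starting_epoch, max_epoch = 1, 20
assert list(range(starting_epoch, max_epoch))[-1] == 19      # old: last epoch is 19
assert list(range(starting_epoch, max_epoch + 1))[-1] == 20  # new: last epoch is 20
```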