Commit 76fe1b6b authored by Yuxin Wu

update cifar number & fix multigpu restore bug

parent da3da39d
@@ -24,7 +24,7 @@ This implementation uses the variants proposed in:
 Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results for
-n=5, about 7.6% val error
+n=5, about 7.2% val error after 93k steps with 2 TitanX (6.8it/s)
 n=18, about 6.05% val error after 62k steps with 2 TitanX (about 10hr)
 n=30: a 182-layer network, about 5.5% val error after 51k steps with 2 GPUs
 This model uses the whole training set instead of a 95:5 train-val split.
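For context, the n in these configurations counts residual blocks per stage, and the depth of this CIFAR-style ResNet follows 6n+2, which is where the 182 comes from. A one-line check in plain Python:

    # CIFAR ResNet depth: 3 stages x n blocks x 2 conv layers, plus the first conv and final FC
    for n in (5, 18, 30):
        print('n=%d -> %d layers' % (n, 6 * n + 2))   # 32, 110, 182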
@@ -20,8 +20,9 @@ from tensorpack.dataflow import imgaug
 """
-Reach 1.9% validation error after 90 epochs, with 2 GPUs.
-You might need to adjust learning rate schedule when running with 1 GPU.
+ResNet-110 for SVHN Digit Classification.
+Reach 1.9% validation error after 90 epochs, with 2 TitanX xxhr, 2it/s.
+You might need to adjust the learning rate schedule when running with 1 GPU.
 """
 BATCH_SIZE = 128
@@ -98,8 +99,7 @@ class Model(ModelDesc):
         logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
         prob = tf.nn.softmax(logits, name='output')
-        y = one_hot(label, 10)
-        cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
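The loss rewrite above drops the explicit one-hot encoding: the sparse variant takes integer class ids directly and computes the same per-example loss. A minimal sketch using the TF 0.x-era positional signatures that appear in this diff (`tf.one_hot` stands in for the repo's own `one_hot` helper):

    import tensorflow as tf

    logits = tf.constant([[2.0, 0.5, 0.1], [0.2, 3.0, 0.4]])
    labels = tf.constant([0, 1])   # integer class ids

    # dense form: needs an explicit one-hot tensor for the labels
    dense = tf.nn.softmax_cross_entropy_with_logits(logits, tf.one_hot(labels, 3))
    # sparse form: consumes the ids directly; same per-example values
    sparse = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)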
@@ -167,8 +167,8 @@ def get_config():
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
-            ValidationError(dataset_test, prefix='test'),
+            ModelSaver(),
+            ClassificationError(dataset_test, prefix='validation'),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
         ]),
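For reference, the schedule passed to `ScheduledHyperParamSetter` is a list of `(epoch, value)` pairs; each value is installed at its epoch and held until the next entry, so the net effect is a piecewise-constant learning rate. A plain-Python sketch of that lookup (`lr_at` is a hypothetical helper, not tensorpack API):

    def lr_at(epoch, schedule=((1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001))):
        # return the value of the latest entry whose epoch has been reached
        lr = None
        for e, v in schedule:
            if epoch >= e:
                lr = v
        return lr

    assert lr_at(5) == 0.1 and lr_at(25) == 0.01 and lr_at(90) == 0.0001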
@@ -114,7 +114,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
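The added parentheses are a real bug fix: Python's conditional expression binds more loosely than `*`, so the old line parsed as `(step_per_epoch * 30) if nr_gpu == 1 else 20`, decaying the learning rate every 20 global steps on multi-GPU runs instead of every 20 epochs. A quick demonstration:

    step_per_epoch, nr_gpu = 100, 2

    old = step_per_epoch * 30 if nr_gpu == 1 else 20    # parses as (100*30) if ... else 20 -> 20
    new = step_per_epoch * (30 if nr_gpu == 1 else 20)  # -> 2000
    assert old == 20 and new == 2000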
@@ -129,7 +129,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=3,
+        max_epoch=20,
     )
 if __name__ == '__main__':
@@ -26,11 +26,12 @@ class ModelSaver(Callback):
     def _before_train(self):
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
-            var_list=self._get_vars(),
+            var_list=ModelSaver._get_vars(),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)

-    def _get_vars(self):
+    @staticmethod
+    def _get_vars():
         vars = tf.all_variables()
         var_dict = {}
         for v in vars:
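The body of `_get_vars` is cut off by the diff context above; judging from the restore logic later in this commit, it builds a tower-free `{name: variable}` dict so checkpoints written by `ModelSaver` carry no `towerN/` prefixes. A hedged reconstruction of that mapping (my sketch, not the commit's exact code; `get_save_vars` is a hypothetical name):

    import re
    import tensorflow as tf

    def get_save_vars():
        # keep one variable per checkpoint key, with any towerN/ prefix stripped,
        # so the saved checkpoint is independent of the number of training towers
        var_dict = {}
        for v in tf.all_variables():              # TF 0.x API, as used in the diff
            name = re.sub('tower[0-9]+/', '', v.op.name)
            if name not in var_dict:              # towers share weights; first copy wins
                var_dict[name] = v
        return var_dict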
@@ -70,6 +70,7 @@ class TestCallbackContext(object):
         with create_test_session(trainer) as sess:
             self.sess = sess
             self.graph = sess.graph
+            # no tower in the test graph, so a plain Saver suffices
             self.saver = tf.train.Saver()
         with self.graph.as_default(), self.sess.as_default():
             yield
@@ -5,6 +5,8 @@
 import os
 from abc import abstractmethod, ABCMeta
 import numpy as np
+from collections import defaultdict
+import re
 import tensorflow as tf
 import six
@@ -38,7 +40,7 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by `tf.Saver`.
+    Restore an old model saved by `ModelSaver`.
     """
     def __init__(self, model_path):
         """
@@ -52,14 +54,60 @@
         self.set_path(model_path)

     def _init(self, sess):
-        saver = tf.train.Saver()
-        saver.restore(sess, self.path)
         logger.info(
-            "Restore checkpoint from {}".format(self.path))
+            "Restoring checkpoint from {}.".format(self.path))
+        sess.run(tf.initialize_all_variables())
+        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
+        for dic in SaverRestore._produce_restore_dict(vars_map):
+            saver = tf.train.Saver(var_list=dic)
+            saver.restore(sess, self.path)

     def set_path(self, model_path):
         self.path = model_path

+    @staticmethod
+    def _produce_restore_dict(vars_multimap):
+        """
+        Produce {var_name: var} dict that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
+        """
+        while len(vars_multimap):
+            ret = {}
+            for k in vars_multimap.keys():
+                v = vars_multimap[k]
+                ret[k] = v[-1]
+                del v[-1]
+                if not len(v):
+                    del vars_multimap[k]
+            yield ret
+
+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        reader = tf.train.NewCheckpointReader(model_path)
+        return set(reader.GetVariableToShapeMap().keys())
+
+    @staticmethod
+    def _get_vars_to_restore_multimap(vars_available):
+        """
+        Get a dict of {var_name: [var, var]} to restore
+        :param vars_available: variables available in the checkpoint, for existence checking
+        """
+        # TODO warn if some variable in checkpoint is not used
+        vars_to_restore = tf.all_variables()
+        var_dict = defaultdict(list)
+        for v in vars_to_restore:
+            name = v.op.name
+            if 'tower' in name:
+                new_name = re.sub('tower[0-9]+/', '', name)
+                name = new_name
+            if name in vars_available:
+                var_dict[name].append(v)
+            else:
+                logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        return var_dict

 class ParamRestore(SessionInit):
     """
     Restore trainable variables from a dictionary.
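This restore path is the multi-GPU fix named in the commit message: each training tower holds its own copy of a variable (`tower0/conv1/W`, `tower1/conv1/W`, ...) while the checkpoint stores a single `conv1/W`, and a `tf.train.Saver` var_list maps one checkpoint key to one variable, so the restore must run once per copy. A standalone sketch of the splitting done by `_produce_restore_dict` (strings stand in for variables; `produce_restore_dict` is a hypothetical equivalent):

    multimap = {'conv1/W': ['tower0/conv1/W', 'tower1/conv1/W'],
                'fc/b': ['fc/b']}

    def produce_restore_dict(vars_multimap):
        # peel off one variable per checkpoint key until every copy is consumed
        while vars_multimap:
            ret = {}
            for k in list(vars_multimap.keys()):
                ret[k] = vars_multimap[k].pop()
                if not vars_multimap[k]:
                    del vars_multimap[k]
            yield ret

    print(list(produce_restore_dict(multimap)))
    # [{'conv1/W': 'tower1/conv1/W', 'fc/b': 'fc/b'}, {'conv1/W': 'tower0/conv1/W'}]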
@@ -83,7 +83,7 @@ class Trainer(object):
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))
-        for epoch in range(self.config.starting_epoch, self.config.max_epoch):
+        for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
             with timed_operation(
                 'Epoch {}, global_step={}'.format(
                     epoch, self.global_step + self.config.step_per_epoch)):
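The `+1` fixes an off-by-one: `range` excludes its upper bound, so with `starting_epoch=1` the old loop ran only `max_epoch - 1` epochs. For example:

    starting_epoch, max_epoch = 1, 20
    assert len(range(starting_epoch, max_epoch)) == 19      # old: one epoch short
    assert len(range(starting_epoch, max_epoch + 1)) == 20  # new: exactly max_epoch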