Commit 5a8d500c authored by Yuxin Wu's avatar Yuxin Wu

ChainInit & verbose download error

parent 7e32ccc7
......@@ -8,3 +8,6 @@ The train error shown here is a moving average of the error rate of each batch i
The validation error here is computed on test set.
![cifar10](https://github.com/ppwwyyxx/tensorpack/raw/master/examples/ResNet/cifar10-resnet.png)
Download model:
[Cifar10 n=18](https://drive.google.com/open?id=0B308TeQzmFDLeHpSaHAxWGV1WDg)
......@@ -178,7 +178,7 @@ class ClassificationError(Inferencer):
return [self.wrong_var_name]
def _before_inference(self):
self.err_stat = Accuracy()
self.err_stat = RatioCounter()
def _datapoint(self, dp, outputs):
batch_size = dp[0].shape[0] # assume batched input
......@@ -186,7 +186,7 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size)
def _after_inference(self):
self.trainer.write_scalar_summary(self.summary_name, self.err_stat.accuracy)
self.trainer.write_scalar_summary(self.summary_name, self.err_stat.ratio)
class BinaryClassificationStats(Inferencer):
""" Compute precision/recall in binary classification, given the
......
......@@ -50,7 +50,8 @@ class ILSVRCMeta(object):
proto_path = download(CAFFE_PROTO_URL, self.dir)
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(self.dir))
assert ret == 0, "caffe proto compilation failed!"
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
def get_image_list(self, name):
"""
......
......@@ -13,7 +13,7 @@ import six
from ..utils import logger
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore',
'ParamRestore', 'ChainInit',
'JustCurrentSession',
'dump_session_params']
......@@ -65,7 +65,6 @@ class SaverRestore(SessionInit):
def _init(self, sess):
logger.info(
"Restoring checkpoint from {}.".format(self.path))
sess.run(tf.initialize_all_variables())
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map):
......@@ -131,7 +130,6 @@ class ParamRestore(SessionInit):
self.prms = param_dict
def _init(self, sess):
sess.run(tf.initialize_all_variables())
# allow restore non-trainable variables
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
var_dict = dict([v.name, v] for v in variables)
......@@ -152,6 +150,21 @@ class ParamRestore(SessionInit):
value = value.reshape(varshape)
sess.run(var.assign(value))
def ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance."""
def __init__(self, sess_inits, new_session=True):
"""
:params sess_inits: list of `SessionInit` instances.
:params new_session: add a `NewSession()` and the beginning, if not there
"""
if new_session and not isinstance(sess_inits[0], NewSession):
sess_inits.insert(0, NewSession())
self.inits = sess_inits
def _init(self, sess):
for i in self.inits:
i.init(sess)
def dump_session_params(path):
""" Dump value of all trainable variables to a dict and save to `path` as
npy format, loadable by ParamRestore
......
......@@ -105,6 +105,7 @@ class Trainer(object):
get_global_step_var() # ensure there is such var, before finalizing the graph
callbacks = self.config.callbacks
callbacks.setup_graph(self)
self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
self._start_concurrency()
......
......@@ -47,7 +47,7 @@ class TrainConfig(object):
self.session_config = kwargs.pop('session_config', get_default_sess_config())
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession())
self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
......
......@@ -27,10 +27,16 @@ def download(url, dir):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
try:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress)
statinfo = os.stat(fpath)
size = statinfo.st_size
except:
logger.error("Failed to download {}".format(url))
raise
assert size > 0, "Download an empty file!"
sys.stdout.write('\n')
print('Succesfully downloaded ' + fname + " " + str(statinfo.st_size) + ' bytes.')
print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
return fpath
if __name__ == '__main__':
......
......@@ -6,6 +6,7 @@ import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter']
class StatCounter(object):
""" A simple counter"""
def __init__(self):
self.reset()
......@@ -35,6 +36,7 @@ class StatCounter(object):
return max(self._values)
class RatioCounter(object):
""" A counter to count ratio of something"""
def __init__(self):
self.reset()
......@@ -57,6 +59,7 @@ class RatioCounter(object):
return self._tot
class Accuracy(RatioCounter):
""" A RatioCounter with a fancy name """
@property
def accuracy(self):
return self.ratio
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment