Commit dbd0d542 authored by Yuxin Wu

top5 error

parent d4f925a9
 # tensorpack
 Neural Network Toolbox on TensorFlow
-In development. See [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn.
+In development. API might change a bit. See [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn.
 ## Features:
 + Scoped abstraction of common models.
...
@@ -105,11 +105,15 @@ class Model(ModelDesc):
         for k in [cost, loss1, loss2, loss3]:
             tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, k)
 
-        wrong = prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
+        wrong = prediction_incorrect(logits, label, 1)
+        nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
         tf.add_to_collection(
-            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error_top1'))
+        wrong = prediction_incorrect(logits, label, 5)
+        nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
+        tf.add_to_collection(
+            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error_top5'))
 
         # weight decay on all W of fc layers
         wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
@@ -163,7 +167,9 @@ def get_config():
         callbacks=Callbacks([
             StatPrinter(),
             ModelSaver(),
-            InferenceRunner(dataset_val, ClassificationError()),
+            InferenceRunner(dataset_val, [
+                ClassificationError('wrong-top1', 'val-top1-error'),
+                ClassificationError('wrong-top5', 'val-top5-error')]),
             #HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
             ScheduledHyperParamSetter('learning_rate',
                                       [(8, 0.03), (13, 0.02), (21, 5e-3),
...
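A note on how the pieces above fit together: each `ClassificationError`'s first argument must match the name the model gave its error-counter tensor via `tf.reduce_sum(..., name=...)`, while the second is only the label under which the aggregated error is logged. A minimal, purely illustrative sketch of that correspondence (names taken from this commit's diff; surrounding model and config setup omitted):

```python
# In the model (see the hunk above): one counter tensor per metric,
# retrievable later by its graph name.
wrong = prediction_incorrect(logits, label, 1)
nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')

# In get_config(): first arg = tensor name to fetch during validation,
# second arg = name used when logging the aggregated error.
ClassificationError('wrong-top1', 'val-top1-error')
```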
@@ -162,20 +162,21 @@ class ClassificationError(Inferencer):
     """
     Validate the accuracy from a `wrong` variable.
-    The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
+    The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
+    You can use `tf.nn.in_top_k` to record top-k error as well.
 
     This callback produces the "true" error,
     taking into account the fact that batches might not have the same size in
     testing (because the size of the test set might not be a multiple of the batch size).
     In theory, the result could be different from what is produced by ValidationStatPrinter.
     """
 
-    def __init__(self, wrong_var_name='wrong:0', prefix='validation'):
+    def __init__(self, wrong_var_name='wrong:0', summary_name='validation_error'):
         """
         :param wrong_var_name: name of the `wrong` variable
-        :param prefix: an optional prefix for logging
+        :param summary_name: the name used for logging
         """
         self.wrong_var_name = wrong_var_name
-        self.prefix = prefix
+        self.summary_name = summary_name
 
     def _get_output_tensors(self):
         return [self.wrong_var_name]
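The docstring's point about the "true" error can be made concrete with plain numbers: the callback accumulates the raw count of wrong samples per batch and divides by the total number of samples, which differs from averaging per-batch error rates when the last batch is smaller. A toy illustration (plain Python, not part of the commit; the counts are made up):

```python
# (nr_wrong, batch_size) per validation batch; the last batch is smaller
# because the test set size is not a multiple of the batch size.
batches = [(13, 128), (9, 128), (2, 40)]

true_error = sum(w for w, _ in batches) / float(sum(n for _, n in batches))
naive_error = sum(w / float(n) for w, n in batches) / len(batches)

print(true_error)   # ~0.081: 24 wrong out of 296 samples
print(naive_error)  # ~0.074: biased, the 40-sample batch weighs as much as a 128-sample one
```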
@@ -190,6 +191,6 @@ class ClassificationError(Inferencer):
 
     def _after_inference(self):
         self.trainer.summary_writer.add_summary(
-            create_summary('{}_error'.format(self.prefix), self.err_stat.accuracy),
+            create_summary(self.summary_name, self.err_stat.accuracy),
             get_global_step())
-        self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), self.err_stat.accuracy)
+        self.trainer.stat_holder.add_stat(self.summary_name, self.err_stat.accuracy)
@@ -21,18 +21,13 @@ def one_hot(y, num_labels):
     onehot_labels.set_shape([None, num_labels])
     return tf.cast(onehot_labels, tf.float32)
 
-def prediction_incorrect(logits, label):
+def prediction_incorrect(logits, label, topk=1):
     """
     :param logits: NxC
     :param label: N
-    :returns: a binary vector of length N with 1 meaning incorrect prediction
-    """
-    with tf.op_scope([logits, label], 'incorrect'):
-        wrong = tf.not_equal(
-            tf.argmax(logits, 1),
-            tf.cast(label, tf.int64))
-        wrong = tf.cast(wrong, tf.float32)
-        return wrong
+    :returns: a float32 vector of length N with 0/1 values, 1 meaning incorrect prediction
+    """
+    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)), tf.float32)
 
 def flatten(x):
     """
...
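As a sanity check of the rewritten helper: `tf.nn.in_top_k(logits, label, k)` marks a sample correct when its true label is among the `k` largest logits, so top-k error can only be lower than or equal to top-1 error. A toy example (hypothetical values; assumes `prediction_incorrect` is imported from the module above and the TF 0.x-era API used throughout this commit):

```python
import tensorflow as tf

logits = tf.constant([[0.1, 0.5, 0.3, 0.1],   # top-1: class 1; top-2: {1, 2}
                      [0.4, 0.1, 0.2, 0.3]])  # top-1: class 0; top-2: {0, 3}
label = tf.constant([2, 3])                   # true classes

top1_wrong = prediction_incorrect(logits, label, 1)  # evaluates to [1., 1.]
top2_wrong = prediction_incorrect(logits, label, 2)  # evaluates to [0., 0.]
```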