Commit dbd0d542 authored by Yuxin Wu's avatar Yuxin Wu

top5 error

parent d4f925a9
# tensorpack # tensorpack
Neural Network Toolbox on TensorFlow Neural Network Toolbox on TensorFlow
In development. See [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn. In development. API might change a bit. See [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn.
## Features: ## Features:
+ Scoped abstraction of common models. + Scoped abstraction of common models.
......
...@@ -105,11 +105,15 @@ class Model(ModelDesc): ...@@ -105,11 +105,15 @@ class Model(ModelDesc):
for k in [cost, loss1, loss2, loss3]: for k in [cost, loss1, loss2, loss3]:
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, k) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, k)
wrong = prediction_incorrect(logits, label) wrong = prediction_incorrect(logits, label, 1)
nr_wrong = tf.reduce_sum(wrong, name='wrong') nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
# monitor training error
tf.add_to_collection( tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error')) MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error_top1'))
wrong = prediction_incorrect(logits, label, 5)
nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error_top5'))
# weight decay on all W of fc layers # weight decay on all W of fc layers
wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(), wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
...@@ -163,7 +167,9 @@ def get_config(): ...@@ -163,7 +167,9 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
InferenceRunner(dataset_val, ClassificationError()), InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-top1-error'),
ClassificationError('wrong-top5', 'val-top5-error')]),
#HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt') #HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(8, 0.03), (13, 0.02), (21, 5e-3), [(8, 0.03), (13, 0.02), (21, 5e-3),
......
...@@ -162,20 +162,21 @@ class ClassificationError(Inferencer): ...@@ -162,20 +162,21 @@ class ClassificationError(Inferencer):
""" """
Validate the accuracy from a `wrong` variable Validate the accuracy from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
You can use `tf.nn.in_top_k` to record top-k error as well.
This callback produce the "true" error, This callback produce the "true" error,
taking account of the fact that batches might not have the same size in taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size). testing (because the size of test set might not be a multiple of batch size).
In theory, the result could be different from what produced by ValidationStatPrinter. In theory, the result could be different from what produced by ValidationStatPrinter.
""" """
def __init__(self, wrong_var_name='wrong:0', prefix='validation'): def __init__(self, wrong_var_name='wrong:0', summary_name='validation_error'):
""" """
:param wrong_var_name: name of the `wrong` variable :param wrong_var_name: name of the `wrong` variable
:param prefix: an optional prefix for logging :param summary_name: an optional prefix for logging
""" """
self.wrong_var_name = wrong_var_name self.wrong_var_name = wrong_var_name
self.prefix = prefix self.summary_name = summary_name
def _get_output_tensors(self): def _get_output_tensors(self):
return [self.wrong_var_name] return [self.wrong_var_name]
...@@ -190,6 +191,6 @@ class ClassificationError(Inferencer): ...@@ -190,6 +191,6 @@ class ClassificationError(Inferencer):
def _after_inference(self): def _after_inference(self):
self.trainer.summary_writer.add_summary( self.trainer.summary_writer.add_summary(
create_summary('{}_error'.format(self.prefix), self.err_stat.accuracy), create_summary(self.summary_name, self.err_stat.accuracy),
get_global_step()) get_global_step())
self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), self.err_stat.accuracy) self.trainer.stat_holder.add_stat(self.summary_name, self.err_stat.accuracy)
...@@ -21,18 +21,13 @@ def one_hot(y, num_labels): ...@@ -21,18 +21,13 @@ def one_hot(y, num_labels):
onehot_labels.set_shape([None, num_labels]) onehot_labels.set_shape([None, num_labels])
return tf.cast(onehot_labels, tf.float32) return tf.cast(onehot_labels, tf.float32)
def prediction_incorrect(logits, label): def prediction_incorrect(logits, label, topk=1):
""" """
:param logits: NxC :param logits: NxC
:param label: N :param label: N
:returns: a binary vector of length N with 1 meaning incorrect prediction :returns: a float32 vector of length N with 0/1 values, 1 meaning incorrect prediction
""" """
with tf.op_scope([logits, label], 'incorrect'): return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)), tf.float32)
wrong = tf.not_equal(
tf.argmax(logits, 1),
tf.cast(label, tf.int64))
wrong = tf.cast(wrong, tf.float32)
return wrong
def flatten(x): def flatten(x):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment