Commit 80622ae7 authored by Yuxin Wu

validation callback printer

parent 9fe18ff8
......@@ -20,7 +20,7 @@ from tensorpack.dataflow import *
"""
MNIST ConvNet example.
99.3% validation accuracy after 50 epochs.
99.25% validation accuracy after 50 epochs.
"""
BATCH_SIZE = 128
......@@ -107,6 +107,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ValidationError(dataset_test, prefix='validation'),
]),
session_config=sess_config,
......
......@@ -6,87 +6,102 @@
import tensorflow as tf
import itertools
from tqdm import tqdm
from abc import ABCMeta
from ..utils import *
from ..utils.stat import *
from ..utils.summary import *
from .base import PeriodicCallback, Callback, TestCallback
__all__ = ['ValidationError', 'ValidationCallback']
__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(PeriodicCallback):
type = TestCallback()
"""
Basic routine for validation callbacks.
Base class for validation callbacks.
"""
def __init__(self, ds, prefix, period=1, cost_var_name='cost:0'):
def __init__(self, ds, prefix, period=1):
super(ValidationCallback, self).__init__(period)
self.ds = ds
self.prefix = prefix
self.cost_var_name = cost_var_name
def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name)
self._find_output_vars()
def get_tensor(self, name):
return self.graph.get_tensor_by_name(name)
@abstractmethod
def _find_output_vars(self):
pass
""" prepare output variables. Will be called in before_train"""
@abstractmethod
def _get_output_vars(self):
return []
""" return a list of output vars to eval"""
def _run_validation(self):
"""
Generator to return inputs and outputs
Eval the vars, generate inputs and outputs
"""
cnt = 0
cost_sum = 0
output_vars = self._get_output_vars()
output_vars.append(self.cost_var)
sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input
cnt += batch_size
outputs = sess.run(output_vars, feed_dict=feed)
cost = outputs[-1]
# each batch might not have the same size in validation
cost_sum += cost * batch_size
yield (dp, outputs[:-1])
yield (dp, outputs)
pbar.update()
cost_avg = cost_sum / cnt
self.trainer.summary_writer.add_summary(create_summary(
'{}_cost'.format(self.prefix), cost_avg), self.global_step)
self.trainer.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
@abstractmethod
def _trigger_periodic(self):
""" Implement the actual callback"""
class ValidationStatPrinter(ValidationCallback):
"""
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged over all batches in the validation set.
"""
def __init__(self, ds, names_to_print, prefix='validation', period=1):
super(ValidationStatPrinter, self).__init__(ds, prefix, period)
self.names = names_to_print
def _find_output_vars(self):
self.vars_to_print = [self.get_tensor(n) for n in self.names]
def _get_output_vars(self):
return self.vars_to_print
def _trigger_periodic(self):
stats = []
for dp, outputs in self._run_validation():
pass
stats.append(outputs)
stats = np.mean(stats, axis=0)
assert len(stats) == len(self.vars_to_print)
for stat, var in itertools.izip(stats, self.vars_to_print):
name = var.name.replace(':0', '')
self.trainer.summary_writer.add_summary(create_summary(
'{}_{}'.format(self.prefix, name), stat), self.global_step)
self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)
class ValidationError(ValidationCallback):
running_graph = 'test'
"""
Validate the accuracy for the given wrong and cost variable
Use under the following setup:
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
Validate the accuracy from a 'wrong' variable
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
This callback produces the "true" error,
taking into account the fact that batches might not have the same size in
testing (because the size of the test set might not be a multiple of the batch size).
In theory, the result could be different from what is produced by ValidationStatPrinter.
"""
def __init__(self, ds, prefix,
def __init__(self, ds, prefix='validation',
period=1,
wrong_var_name='wrong:0',
cost_var_name='cost:0'):
super(ValidationError, self).__init__(
ds, prefix, period, cost_var_name)
wrong_var_name='wrong:0'):
super(ValidationError, self).__init__(ds, prefix, period)
self.wrong_var_name = wrong_var_name
def _find_output_vars(self):
......
......@@ -41,19 +41,20 @@ def add_param_summary(summary_lists):
"""
def perform(var, action):
ndim = var.get_shape().ndims
name = var.name.replace(':0', '')
if action == 'scalar':
assert ndim == 0, "Scalar summary on high-dimension data. Maybe you want 'mean'?"
tf.scalar_summary(var.name, var)
tf.scalar_summary(name, var)
return
assert ndim > 0, "Cannot perform {} summary on scalar data".format(action)
if action == 'histogram':
tf.histogram_summary(var.name, var)
tf.histogram_summary(name, var)
return
if action == 'sparsity':
tf.scalar_summary(var.name + '/sparsity', tf.nn.zero_fraction(var))
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(var))
return
if action == 'mean':
tf.scalar_summary(var.name + '/mean', tf.reduce_mean(var))
tf.scalar_summary(name + '/mean', tf.reduce_mean(var))
return
raise RuntimeError("Unknown action {}".format(action))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment