Commit 80622ae7 authored by Yuxin Wu's avatar Yuxin Wu

validation callback printer

parent 9fe18ff8
...@@ -20,7 +20,7 @@ from tensorpack.dataflow import * ...@@ -20,7 +20,7 @@ from tensorpack.dataflow import *
""" """
MNIST ConvNet example. MNIST ConvNet example.
99.3% validation accuracy after 50 epochs. 99.25% validation accuracy after 50 epochs.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -107,6 +107,7 @@ def get_config(): ...@@ -107,6 +107,7 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
PeriodicSaver(), PeriodicSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ValidationError(dataset_test, prefix='validation'), ValidationError(dataset_test, prefix='validation'),
]), ]),
session_config=sess_config, session_config=sess_config,
......
...@@ -6,87 +6,102 @@ ...@@ -6,87 +6,102 @@
import tensorflow as tf import tensorflow as tf
import itertools import itertools
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta
from ..utils import * from ..utils import *
from ..utils.stat import * from ..utils.stat import *
from ..utils.summary import * from ..utils.summary import *
from .base import PeriodicCallback, Callback, TestCallback from .base import PeriodicCallback, Callback, TestCallback
__all__ = ['ValidationError', 'ValidationCallback'] __all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(PeriodicCallback): class ValidationCallback(PeriodicCallback):
type = TestCallback() type = TestCallback()
""" """
Basic routine for validation callbacks. Base class for validation callbacks.
""" """
def __init__(self, ds, prefix, period=1, cost_var_name='cost:0'): def __init__(self, ds, prefix, period=1):
super(ValidationCallback, self).__init__(period) super(ValidationCallback, self).__init__(period)
self.ds = ds self.ds = ds
self.prefix = prefix self.prefix = prefix
self.cost_var_name = cost_var_name
def _before_train(self): def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars() self.input_vars = self.trainer.model.reuse_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name)
self._find_output_vars() self._find_output_vars()
def get_tensor(self, name): def get_tensor(self, name):
return self.graph.get_tensor_by_name(name) return self.graph.get_tensor_by_name(name)
@abstractmethod
def _find_output_vars(self): def _find_output_vars(self):
pass """ prepare output variables. Will be called in before_train"""
@abstractmethod
def _get_output_vars(self): def _get_output_vars(self):
return [] """ return a list of output vars to eval"""
def _run_validation(self): def _run_validation(self):
""" """
Generator to return inputs and outputs Eval the vars, generate inputs and outputs
""" """
cnt = 0
cost_sum = 0
output_vars = self._get_output_vars() output_vars = self._get_output_vars()
output_vars.append(self.cost_var)
sess = tf.get_default_session() sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar: with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp)) feed = dict(itertools.izip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
cnt += batch_size
outputs = sess.run(output_vars, feed_dict=feed) outputs = sess.run(output_vars, feed_dict=feed)
cost = outputs[-1] yield (dp, outputs)
# each batch might not have the same size in validation
cost_sum += cost * batch_size
yield (dp, outputs[:-1])
pbar.update() pbar.update()
cost_avg = cost_sum / cnt @abstractmethod
self.trainer.summary_writer.add_summary(create_summary( def _trigger_periodic(self):
'{}_cost'.format(self.prefix), cost_avg), self.global_step) """ Implement the actual callback"""
self.trainer.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
class ValidationStatPrinter(ValidationCallback):
"""
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
def __init__(self, ds, names_to_print, prefix='validation', period=1):
super(ValidationStatPrinter, self).__init__(ds, prefix, period)
self.names = names_to_print
def _find_output_vars(self):
self.vars_to_print = [self.get_tensor(n) for n in self.names]
def _get_output_vars(self):
return self.vars_to_print
def _trigger_periodic(self): def _trigger_periodic(self):
stats = []
for dp, outputs in self._run_validation(): for dp, outputs in self._run_validation():
pass stats.append(outputs)
stats = np.mean(stats, axis=0)
assert len(stats) == len(self.vars_to_print)
for stat, var in itertools.izip(stats, self.vars_to_print):
name = var.name.replace(':0', '')
self.trainer.summary_writer.add_summary(create_summary(
'{}_{}'.format(self.prefix, name), stat), self.global_step)
self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)
class ValidationError(ValidationCallback): class ValidationError(ValidationCallback):
running_graph = 'test'
""" """
Validate the accuracy for the given wrong and cost variable Validate the accuracy from a 'wrong' variable
Use under the following setup: wrong_var: integer, number of failed samples in this batch
wrong_var: integer, number of failed samples in this batch ds: batched dataset
ds: batched dataset
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
In theory, the result could be different from what produced by ValidationStatPrinter.
""" """
def __init__(self, ds, prefix, def __init__(self, ds, prefix='validation',
period=1, period=1,
wrong_var_name='wrong:0', wrong_var_name='wrong:0'):
cost_var_name='cost:0'): super(ValidationError, self).__init__(ds, prefix, period)
super(ValidationError, self).__init__(
ds, prefix, period, cost_var_name)
self.wrong_var_name = wrong_var_name self.wrong_var_name = wrong_var_name
def _find_output_vars(self): def _find_output_vars(self):
......
...@@ -41,19 +41,20 @@ def add_param_summary(summary_lists): ...@@ -41,19 +41,20 @@ def add_param_summary(summary_lists):
""" """
def perform(var, action): def perform(var, action):
ndim = var.get_shape().ndims ndim = var.get_shape().ndims
name = var.name.replace(':0', '')
if action == 'scalar': if action == 'scalar':
assert ndim == 0, "Scalar summary on high-dimension data. Maybe you want 'mean'?" assert ndim == 0, "Scalar summary on high-dimension data. Maybe you want 'mean'?"
tf.scalar_summary(var.name, var) tf.scalar_summary(name, var)
return return
assert ndim > 0, "Cannot perform {} summary on scalar data".format(action) assert ndim > 0, "Cannot perform {} summary on scalar data".format(action)
if action == 'histogram': if action == 'histogram':
tf.histogram_summary(var.name, var) tf.histogram_summary(name, var)
return return
if action == 'sparsity': if action == 'sparsity':
tf.scalar_summary(var.name + '/sparsity', tf.nn.zero_fraction(var)) tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(var))
return return
if action == 'mean': if action == 'mean':
tf.scalar_summary(var.name + '/mean', tf.reduce_mean(var)) tf.scalar_summary(name + '/mean', tf.reduce_mean(var))
return return
raise RuntimeError("Unknown action {}".format(action)) raise RuntimeError("Unknown action {}".format(action))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment