Commit 174c3fc9 authored by Yuxin Wu's avatar Yuxin Wu

single-pass inference

parent 76fe1b6b
...@@ -162,7 +162,8 @@ def get_config(): ...@@ -162,7 +162,8 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ClassificationError(dataset_test, prefix='validation'), InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]), ]),
......
...@@ -168,7 +168,8 @@ def get_config(): ...@@ -168,7 +168,8 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ClassificationError(dataset_test, prefix='validation'), InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)]) [(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
]), ]),
......
...@@ -124,12 +124,12 @@ def get_config(): ...@@ -124,12 +124,12 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ClassificationError(dataset_test, prefix='test'), InferenceRunner(dataset_test, ClassificationError())
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=20, max_epoch=300,
) )
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -72,7 +72,7 @@ class Model(ModelDesc): ...@@ -72,7 +72,7 @@ class Model(ModelDesc):
name='regularize_loss') name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor histogram of all W add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost') return tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
...@@ -102,8 +102,8 @@ def get_config(): ...@@ -102,8 +102,8 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']), InferenceRunner(dataset_test,
ClassificationError(dataset_test, prefix='validation'), [ScalarStats('cost'), ClassificationError() ])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
......
...@@ -109,7 +109,8 @@ def get_config(): ...@@ -109,7 +109,8 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ClassificationError(test, prefix='test'), InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
......
# -*- coding: UTF-8 -*-
# File: inference.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip
from ..dataflow import DataFlow
from ..utils import *
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from .base import Callback, TestCallbackType
__all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer']
class Inferencer(object):
    """
    Base class for a single inference operation run over a dataset.

    Subclasses override the ``_``-prefixed hooks; `InferenceRunner` drives
    the public wrappers once per epoch.

    NOTE(review): ``__metaclass__ = ABCMeta`` is the Python 2 spelling; under
    Python 3 it is a no-op so abstract-method enforcement only happens on
    Python 2. Kept as-is to match the file's existing py2-era conventions.
    """
    __metaclass__ = ABCMeta

    def before_inference(self):
        """
        Called before a new round of inference starts.
        """
        self._before_inference()

    def _before_inference(self):
        pass

    def datapoint(self, dp, output):
        """
        Called after completely running every data point.

        :param dp: the input datapoint (list of components).
        :param output: the fetched values of this inferencer's tensors for `dp`.
        """
        self._datapoint(dp, output)

    @abstractmethod
    def _datapoint(self, dp, output):
        pass

    def after_inference(self):
        """
        Called after a round of inference ends.
        """
        self._after_inference()

    def _after_inference(self):
        pass

    def get_output_tensors(self):
        """
        Return a list of tensor names needed for this inference.
        """
        # BUG FIX: previously called the nonexistent `self._get_output_vars()`,
        # which raised AttributeError; the abstract hook is `_get_output_tensors`.
        return self._get_output_tensors()

    @abstractmethod
    def _get_output_tensors(self):
        pass
class InferenceRunner(Callback):
    """
    A callback that runs a list of `Inferencer` over a dataset once per
    epoch. Each datapoint is fed through the session exactly once and the
    fetched outputs are dispatched to every inferencer.
    """
    type = TestCallbackType()

    def __init__(self, ds, vcs):
        """
        :param ds: inference dataset. a `DataFlow` instance.
        :param vcs: a list of `Inferencer` instances, or a single one.
        """
        assert isinstance(ds, DataFlow), type(ds)
        self.ds = ds
        if not isinstance(vcs, list):
            self.vcs = [vcs]
        else:
            self.vcs = vcs
        for v in self.vcs:
            assert isinstance(v, Inferencer), str(v)

    def _before_train(self):
        self.input_vars = self.trainer.model.reuse_input_vars()
        self._find_output_tensors()
        # give each inferencer access to the trainer (summary writer, stats)
        for v in self.vcs:
            v.trainer = self.trainer

    def _find_output_tensors(self):
        # Deduplicate tensors requested by multiple inferencers: every
        # distinct tensor name gets one slot in the shared fetch list, and
        # each inferencer remembers the indices of the slots it needs, so
        # no tensor is evaluated more than once per datapoint.
        self.output_tensors = []
        self.vc_to_vars = []
        for vc in self.vcs:
            vc_vars = vc._get_output_tensors()

            def find_oid(var):
                if var in self.output_tensors:
                    return self.output_tensors.index(var)
                else:
                    self.output_tensors.append(var)
                    return len(self.output_tensors) - 1
            vc_vars = [(var, find_oid(var)) for var in vc_vars]
            self.vc_to_vars.append(vc_vars)

        # convert name to tensors
        def get_tensor(name):
            _, varname = get_op_var_name(name)
            return self.graph.get_tensor_by_name(varname)
        # BUG FIX: wrap in list(). On Python 3 `map` returns a lazy iterator
        # that would be exhausted after the first epoch, leaving later
        # epochs with an empty fetch list. The file targets py2/3 (six).
        self.output_tensors = list(map(get_tensor, self.output_tensors))

    def _trigger_epoch(self):
        for vc in self.vcs:
            vc.before_inference()

        sess = tf.get_default_session()
        with tqdm(total=self.ds.size(), ascii=True) as pbar:
            for dp in self.ds.get_data():
                feed = dict(zip(self.input_vars, dp))  # TODO custom dp mapping?
                outputs = sess.run(self.output_tensors, feed_dict=feed)
                for vc, varsmap in zip(self.vcs, self.vc_to_vars):
                    # slice out only the outputs this inferencer asked for
                    vc_output = [outputs[k[1]] for k in varsmap]
                    vc.datapoint(dp, vc_output)
                pbar.update()

        for vc in self.vcs:
            vc.after_inference()
class ScalarStats(Inferencer):
    """
    Collect the values of some scalar tensors over a whole dataset, then
    write their per-dataset averages as stats and summaries.
    The output of each given tensor must be a scalar.
    """
    def __init__(self, names_to_print, prefix='validation'):
        """
        :param names_to_print: list of names of tensors, or just a name
        :param prefix: an optional prefix for logging
        """
        if isinstance(names_to_print, list):
            self.names = names_to_print
        else:
            self.names = [names_to_print]
        self.prefix = prefix

    def _get_output_tensors(self):
        return self.names

    def _before_inference(self):
        # one entry per datapoint; each entry holds one value per tensor
        self.stats = []

    def _datapoint(self, dp, output):
        self.stats.append(output)

    def _after_inference(self):
        # average over all datapoints -> one mean per requested tensor
        self.stats = np.mean(self.stats, axis=0)
        assert len(self.stats) == len(self.names)
        for avg, tensor_name in zip(self.stats, self.names):
            opname, _ = get_op_var_name(tensor_name)
            if self.prefix:
                label = '{}_{}'.format(self.prefix, opname)
            else:
                label = opname
            self.trainer.summary_writer.add_summary(
                create_summary(label, avg), get_global_step())
            self.trainer.stat_holder.add_stat(label, avg)
class ClassificationError(Inferencer):
    """
    Validate the accuracy from a `wrong` variable.
    The `wrong` variable is supposed to be an integer equal to the number of
    failed samples in this batch. This inferencer produces the "true" error,
    taking into account the fact that batches might not have the same size
    in testing (because the size of the test set might not be a multiple of
    the batch size).
    In theory, the result could be different from what produced by
    ValidationStatPrinter.
    """
    def __init__(self, wrong_var_name='wrong:0', prefix='validation'):
        """
        :param wrong_var_name: name of the `wrong` variable
        :param prefix: an optional prefix for logging
        """
        self.wrong_var_name = wrong_var_name
        self.prefix = prefix

    def _get_output_tensors(self):
        return [self.wrong_var_name]

    def _before_inference(self):
        self.err_stat = Accuracy()

    def _datapoint(self, dp, outputs):
        # assume batched input: the first component's leading dim is batch size
        nr_samples = dp[0].shape[0]
        nr_wrong = int(outputs[0])
        self.err_stat.feed(nr_wrong, nr_samples)

    def _after_inference(self):
        label = '{}_error'.format(self.prefix)
        self.trainer.summary_writer.add_summary(
            create_summary(label, self.err_stat.accuracy),
            get_global_step())
        self.trainer.stat_holder.add_stat(label, self.err_stat.accuracy)
# -*- coding: UTF-8 -*-
# File: validation_callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip
from ..utils import *
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from .base import Callback, TestCallbackType
__all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(Callback):
    """
    Base class for validation callbacks: feed a `DataFlow` through the
    current session once per epoch and let subclasses consume the outputs.
    """
    type = TestCallbackType()

    def __init__(self, ds, prefix):
        """
        :param ds: validation dataset. must be a `DataFlow` instance.
        :param prefix: name to use for this validation.
        """
        self.ds = ds
        self.prefix = prefix

    def _before_train(self):
        self.input_vars = self.trainer.model.reuse_input_vars()
        self._find_output_vars()

    def get_tensor(self, name):
        """
        Get tensor from graph.
        """
        return self.graph.get_tensor_by_name(name)

    @abstractmethod
    def _find_output_vars(self):
        """ prepare output variables. Will be called in before_train"""

    @abstractmethod
    def _get_output_vars(self):
        """ return a list of output vars to eval"""

    def _run_validation(self):
        """
        Eval the output vars over the whole dataset.

        Generator: yields a ``(datapoint, outputs)`` pair per batch.
        """
        output_vars = self._get_output_vars()
        sess = tf.get_default_session()
        with tqdm(total=self.ds.size(), ascii=True) as pbar:
            for dp in self.ds.get_data():
                feed = dict(zip(self.input_vars, dp))
                # (removed an unused `batch_size` local that was computed
                # here but never read; callers derive it from `dp` themselves)
                outputs = sess.run(output_vars, feed_dict=feed)
                yield (dp, outputs)
                pbar.update()
class ValidationStatPrinter(ValidationCallback):
    """
    Write stat and summary of some Op for a validation dataset.
    The result of the given Op must be a scalar, and will be averaged for
    all batches in the validation set.
    """
    def __init__(self, ds, names_to_print, prefix='validation'):
        """
        :param ds: validation dataset. must be a `DataFlow` instance.
        :param names_to_print: names of variables to print
        :param prefix: name to use for this validation.
        """
        super(ValidationStatPrinter, self).__init__(ds, prefix)
        self.names = names_to_print

    def _find_output_vars(self):
        self.vars_to_print = []
        for n in self.names:
            self.vars_to_print.append(self.get_tensor(get_op_var_name(n)[1]))

    def _get_output_vars(self):
        return self.vars_to_print

    def _trigger_epoch(self):
        # gather every batch's outputs, then average over the dataset
        collected = []
        for _, outputs in self._run_validation():
            collected.append(outputs)
        stats = np.mean(collected, axis=0)
        assert len(stats) == len(self.vars_to_print)
        for avg, var in zip(stats, self.vars_to_print):
            stat_name = var.name.replace(':0', '')
            self.trainer.summary_writer.add_summary(create_summary(
                '{}_{}'.format(self.prefix, stat_name), avg), self.global_step)
            self.trainer.stat_holder.add_stat(
                "{}_{}".format(self.prefix, stat_name), avg)
class ClassificationError(ValidationCallback):
    """
    Validate the accuracy from a `wrong` variable.
    The `wrong` variable is supposed to be an integer equal to the number of
    failed samples in this batch. This callback produces the "true" error,
    taking into account the fact that batches might not have the same size
    in testing (because the size of the test set might not be a multiple of
    the batch size).
    In theory, the result could be different from what produced by
    ValidationStatPrinter.
    """
    def __init__(self, ds, prefix='validation',
                 wrong_var_name='wrong:0'):
        """
        :param ds: a batched `DataFlow` instance
        :param wrong_var_name: name of the `wrong` variable
        """
        super(ClassificationError, self).__init__(ds, prefix)
        self.wrong_var_name = wrong_var_name

    def _find_output_vars(self):
        self.wrong_var = self.get_tensor(self.wrong_var_name)

    def _get_output_vars(self):
        return [self.wrong_var]

    def _trigger_epoch(self):
        acc = Accuracy()
        for dp, outputs in self._run_validation():
            # assume batched input: first component's leading dim is batch size
            nr_samples = dp[0].shape[0]
            acc.feed(outputs[0], nr_samples)
        self.trainer.summary_writer.add_summary(create_summary(
            '{}_error'.format(self.prefix), acc.accuracy), self.global_step)
        self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), acc.accuracy)
...@@ -36,7 +36,6 @@ def get_global_step(): ...@@ -36,7 +36,6 @@ def get_global_step():
tf.get_default_session(), tf.get_default_session(),
get_global_step_var()) get_global_step_var())
def get_op_var_name(name): def get_op_var_name(name):
""" """
Variable name is assumed to be ``op_name + ':0'`` Variable name is assumed to be ``op_name + ':0'``
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment