Commit ed7a0793 authored by Yuxin Wu's avatar Yuxin Wu

binaryclassficiation inferencer

parent 73022908
......@@ -15,7 +15,7 @@ from ..tfutils.summary import *
from .base import Callback, TestCallbackType
__all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer']
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
class Inferencer(object):
__metaclass__ = ABCMeta
......@@ -58,6 +58,12 @@ class Inferencer(object):
def _get_output_tensors(self):
pass
def _scalar_summary(self, name, val):
self.trainer.summary_writer.add_summary(
create_summary(name, val),
get_global_step())
self.trainer.stat_holder.add_stat(name, val)
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
......@@ -190,7 +196,29 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size)
def _after_inference(self):
self.trainer.summary_writer.add_summary(
create_summary(self.summary_name, self.err_stat.accuracy),
get_global_step())
self.trainer.stat_holder.add_stat(self.summary_name, self.err_stat.accuracy)
self._scalar_summary(self.summary_name, self.err_stat.accuracy)
class BinaryClassificationStats(Inferencer):
def __init__(self, pred_var_name, label_var_name, summary_prefix='val'):
"""
:param pred_var_name: name of the 0/1 prediction tensor.
:param label_var_name: name of the 0/1 label tensor.
"""
self.pred_var_name = pred_var_name
self.label_var_name = label_var_name
self.prefix = summary_prefix
def _get_output_tensors(self):
return [self.pred_var_name, self.label_var_name]
def _before_inference(self):
self.stat = BinaryStatistics()
def _datapoint(self, dp, outputs):
pred, label = outputs
self.stat.feed(pred, label)
def _after_inference(self):
self._scalar_summary(self.prefix + '_precision', self.stat.precision)
self._scalar_summary(self.prefix + '_recall', self.stat.recall)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment