Commit ed7a0793 authored by Yuxin Wu's avatar Yuxin Wu

binaryclassficiation inferencer

parent 73022908
...@@ -15,7 +15,7 @@ from ..tfutils.summary import * ...@@ -15,7 +15,7 @@ from ..tfutils.summary import *
from .base import Callback, TestCallbackType from .base import Callback, TestCallbackType
__all__ = ['InferenceRunner', 'ClassificationError', __all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer'] 'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
class Inferencer(object): class Inferencer(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
...@@ -58,6 +58,12 @@ class Inferencer(object): ...@@ -58,6 +58,12 @@ class Inferencer(object):
def _get_output_tensors(self): def _get_output_tensors(self):
pass pass
def _scalar_summary(self, name, val):
self.trainer.summary_writer.add_summary(
create_summary(name, val),
get_global_step())
self.trainer.stat_holder.add_stat(name, val)
class InferenceRunner(Callback): class InferenceRunner(Callback):
""" """
A callback that runs different kinds of inferencer. A callback that runs different kinds of inferencer.
...@@ -190,7 +196,29 @@ class ClassificationError(Inferencer): ...@@ -190,7 +196,29 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size) self.err_stat.feed(wrong, batch_size)
def _after_inference(self): def _after_inference(self):
self.trainer.summary_writer.add_summary( self._scalar_summary(self.summary_name, self.err_stat.accuracy)
create_summary(self.summary_name, self.err_stat.accuracy),
get_global_step()) class BinaryClassificationStats(Inferencer):
self.trainer.stat_holder.add_stat(self.summary_name, self.err_stat.accuracy)
def __init__(self, pred_var_name, label_var_name, summary_prefix='val'):
"""
:param pred_var_name: name of the 0/1 prediction tensor.
:param label_var_name: name of the 0/1 label tensor.
"""
self.pred_var_name = pred_var_name
self.label_var_name = label_var_name
self.prefix = summary_prefix
def _get_output_tensors(self):
return [self.pred_var_name, self.label_var_name]
def _before_inference(self):
self.stat = BinaryStatistics()
def _datapoint(self, dp, outputs):
pred, label = outputs
self.stat.feed(pred, label)
def _after_inference(self):
self._scalar_summary(self.prefix + '_precision', self.stat.precision)
self._scalar_summary(self.prefix + '_recall', self.stat.recall)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment