Commit b747c068 authored by Yuxin Wu's avatar Yuxin Wu

Inferencer API renamed

parent 5b681a95
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
from abc import ABCMeta, abstractmethod from abc import ABCMeta
import six import six
from six.moves import zip from six.moves import zip
...@@ -15,12 +15,11 @@ from ..tfutils.common import get_op_tensor_name ...@@ -15,12 +15,11 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer', __all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats'] 'ClassificationError', 'BinaryClassificationStats']
# TODO rename get_output_tensors to get_output_names
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Inferencer(Callback): class Inferencer(Callback):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """ """ Base class of Inferencer.
Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`. """
def _before_epoch(self): def _before_epoch(self):
self._before_inference() self._before_inference()
...@@ -31,20 +30,6 @@ class Inferencer(Callback): ...@@ -31,20 +30,6 @@ class Inferencer(Callback):
""" """
pass pass
def datapoint(self, output):
"""
Called after each new datapoint finished the forward inference.
Args:
output(list): list of output this inferencer needs. Has the same
length as ``self.get_output_tensors()``.
"""
self._datapoint(output)
@abstractmethod
def _datapoint(self, output):
pass
def _trigger_epoch(self): def _trigger_epoch(self):
ret = self._after_inference() ret = self._after_inference()
if ret is None: if ret is None:
...@@ -65,17 +50,44 @@ class Inferencer(Callback): ...@@ -65,17 +50,44 @@ class Inferencer(Callback):
""" """
pass pass
def get_output_tensors(self): def get_fetches(self):
""" """
Return a list of tensor names (guaranteed not op name) this inferencer needs. Return a list of tensor names (guaranteed not op name) this inferencer needs.
""" """
try:
ret = self._get_fetches()
except NotImplementedError:
logger.warn("Inferencer._get_output_tensors was renamed to _get_fetches")
ret = self._get_output_tensors() ret = self._get_output_tensors()
return [get_op_tensor_name(n)[1] for n in ret] return [get_op_tensor_name(n)[1] for n in ret]
@abstractmethod
def _get_output_tensors(self): def _get_output_tensors(self):
pass pass
def _get_fetches(self):
raise NotImplementedError()
def on_fetches(self, results):
"""
Called after each new datapoint finished the forward inference.
Args:
results(list): list of results this inferencer fetched. Has the same
length as ``self._get_fetches()``.
"""
try:
self._on_fetches(results)
except NotImplementedError:
logger.warn("Inferencer._datapoint was renamed to _on_fetches")
self._datapoint(results)
def _datapoint(self, results):
pass
def _on_fetches(self, results):
raise NotImplementedError()
class ScalarStats(Inferencer): class ScalarStats(Inferencer):
""" """
...@@ -96,13 +108,13 @@ class ScalarStats(Inferencer): ...@@ -96,13 +108,13 @@ class ScalarStats(Inferencer):
self.names = names self.names = names
self.prefix = prefix self.prefix = prefix
def _get_output_tensors(self):
return self.names
def _before_inference(self): def _before_inference(self):
self.stats = [] self.stats = []
def _datapoint(self, output): def _get_fetches(self):
return self.names
def _on_fetches(self, output):
self.stats.append(output) self.stats.append(output)
def _after_inference(self): def _after_inference(self):
...@@ -142,13 +154,13 @@ class ClassificationError(Inferencer): ...@@ -142,13 +154,13 @@ class ClassificationError(Inferencer):
self.wrong_tensor_name = wrong_tensor_name self.wrong_tensor_name = wrong_tensor_name
self.summary_name = summary_name self.summary_name = summary_name
def _get_output_tensors(self):
return [self.wrong_tensor_name]
def _before_inference(self): def _before_inference(self):
self.err_stat = RatioCounter() self.err_stat = RatioCounter()
def _datapoint(self, outputs): def _get_fetches(self):
return [self.wrong_tensor_name]
def _on_fetches(self, outputs):
vec = outputs[0] vec = outputs[0]
# TODO put shape assertion into inference-runner # TODO put shape assertion into inference-runner
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_tensor_name) assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_tensor_name)
...@@ -176,13 +188,13 @@ class BinaryClassificationStats(Inferencer): ...@@ -176,13 +188,13 @@ class BinaryClassificationStats(Inferencer):
self.label_tensor_name = label_tensor_name self.label_tensor_name = label_tensor_name
self.prefix = prefix self.prefix = prefix
def _get_output_tensors(self):
return [self.pred_tensor_name, self.label_tensor_name]
def _before_inference(self): def _before_inference(self):
self.stat = BinaryStatistics() self.stat = BinaryStatistics()
def _datapoint(self, outputs): def _get_fetches(self):
return [self.pred_tensor_name, self.label_tensor_name]
def _on_fetches(self, outputs):
pred, label = outputs pred, label = outputs
self.stat.feed(pred, label) self.stat.feed(pred, label)
......
...@@ -39,7 +39,7 @@ class InferencerToHook(tf.train.SessionRunHook): ...@@ -39,7 +39,7 @@ class InferencerToHook(tf.train.SessionRunHook):
return tf.train.SessionRunArgs(fetches=self._fetches) return tf.train.SessionRunArgs(fetches=self._fetches)
def after_run(self, _, run_values): def after_run(self, _, run_values):
self._inf.datapoint(run_values.results) self._inf.on_fetches(run_values.results)
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -136,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -136,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase):
input, infs, tower_name=tower_name, extra_hooks=extra_hooks) input, infs, tower_name=tower_name, extra_hooks=extra_hooks)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_fetches()
fetches = self._tower_handle.get_tensors(out_names) fetches = self._tower_handle.get_tensors(out_names)
return InferencerToHook(inf, fetches) return InferencerToHook(inf, fetches)
...@@ -199,16 +199,16 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -199,16 +199,16 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
res = run_values.results res = run_values.results
for i in range(0, len(res), self._sz): for i in range(0, len(res), self._sz):
vals = res[i:i + self._sz] vals = res[i:i + self._sz]
self._inf.datapoint(vals) self._inf.on_fetches(vals)
def _build_hook_parallel(self, inf): def _build_hook_parallel(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_fetches()
sz = len(out_names) sz = len(out_names)
fetches = list(itertools.chain(*[t.get_tensors(out_names) for t in self._handles])) fetches = list(itertools.chain(*[t.get_tensors(out_names) for t in self._handles]))
return self.InferencerToHookDataParallel(inf, fetches, sz) return self.InferencerToHookDataParallel(inf, fetches, sz)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_fetches()
fetches = self._handles[0].get_tensors(out_names) fetches = self._handles[0].get_tensors(out_names)
return InferencerToHook(inf, fetches) return InferencerToHook(inf, fetches)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment