Commit b7de25d9 authored by Yuxin Wu

make inferencer a callback

parent d44d32d6
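For context: after this commit an Inferencer is itself a Callback, so InferenceRunner drives it through the callback lifecycle (setup_graph, before_epoch, trigger_epoch) instead of the old before_inference/after_inference pair, and the dict returned by _after_inference is pushed to the trainer's monitors. A minimal sketch of a user-defined inferencer written against the new interface follows; the class name, tensor name, and the _get_output_tensors hook are illustrative assumptions, not part of this diff.

# Hypothetical example (not part of this commit): a custom Inferencer
# using the callback-style hooks introduced here.
class MeanAbsError(Inferencer):
    def __init__(self, tensor_name='mean_abs_error'):
        # name of a scalar tensor produced by the model; placeholder value
        self.tensor_name = tensor_name

    def _get_output_tensors(self):
        # assumed overridable counterpart of get_output_tensors() shown below
        return [self.tensor_name]

    def _before_inference(self):
        # reset accumulators before each round of inference
        self._sum, self._cnt = 0.0, 0

    def _datapoint(self, output):
        # `output` holds the fetched values for one datapoint
        self._sum += float(output[0])
        self._cnt += 1

    def _after_inference(self):
        # the returned scalars are sent to trainer.monitors by _trigger_epoch
        return {self.tensor_name: self._sum / max(self._cnt, 1)}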
@@ -7,6 +7,8 @@ from abc import ABCMeta, abstractmethod
 import six
 from six.moves import zip
+from .base import Callback
+from ..utils import logger
 from ..utils.stats import RatioCounter, BinaryStatistics
 from ..tfutils.common import get_op_tensor_name
@@ -17,16 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer',
 @six.add_metaclass(ABCMeta)
-class Inferencer(object):
+class Inferencer(Callback):
     """ Base class of Inferencer. To be used with :class:`InferenceRunner`. """
 
-    def before_inference(self):
-        """
-        Called before a new round of inference starts.
-        """
+    def _before_epoch(self):
         self._before_inference()
 
     def _before_inference(self):
+        """
+        Called before a new round of inference starts.
+        """
         pass
 
     def datapoint(self, output):
@@ -43,14 +45,24 @@ class Inferencer(object):
     def _datapoint(self, output):
         pass
 
-    def after_inference(self):
+    def _trigger_epoch(self):
+        ret = self._after_inference()
+        if ret is None:
+            return
+        for k, v in six.iteritems(ret):
+            try:
+                v = float(v)
+            except:
+                logger.warn("{} returns a non-scalar statistics!".format(type(self).__name__))
+                continue
+            else:
+                self.trainer.monitors.put_scalar(k, v)
+
+    def _after_inference(self):
         """
         Called after a round of inference ends.
         Returns a dict of scalar statistics which will be logged to monitors.
         """
-        return self._after_inference()
-
-    def _after_inference(self):
         pass
 
     def get_output_tensors(self):
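As a side note on the _trigger_epoch logic just above: each value in the returned dict must be coercible with float(); anything else is skipped with a warning rather than sent to the monitors. A small standalone illustration of that rule (plain Python, independent of tensorpack):

import six

def keep_scalars_only(ret):
    # mirrors the float() coercion in _trigger_epoch: non-scalar values are dropped
    out = {}
    for k, v in six.iteritems(ret):
        try:
            out[k] = float(v)
        except (TypeError, ValueError):
            print("{} returns a non-scalar statistics!".format(k))
    return out

print(keep_scalars_only({'val-error': 0.17, 'confusion-matrix': [[3, 1], [2, 4]]}))
# prints a warning for 'confusion-matrix' and then {'val-error': 0.17}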
@@ -42,20 +42,6 @@ class InferencerToHook(tf.train.SessionRunHook):
         self._inf.datapoint(run_values.results)
 
 
-def summary_inferencer(trainer, infs):
-    for inf in infs:
-        ret = inf.after_inference()
-        if ret is None:
-            continue
-        for k, v in six.iteritems(ret):
-            try:
-                v = float(v)
-                trainer.monitors.put_scalar(k, v)
-            except:
-                logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
-                continue
-
-
 @six.add_metaclass(ABCMeta)
 class InferenceRunnerBase(Callback):
     """ Base methods for inference runner"""
@@ -101,6 +87,9 @@ class InferenceRunnerBase(Callback):
         self._hooks = [self._build_hook(inf) for inf in self.infs]
         self._hooks.extend([CallbackToHook(cb) for cb in cbs])
+
+        for inf in self.infs:
+            inf.setup_graph(self.trainer)
 
     def _before_train(self):
         self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
@@ -111,7 +100,7 @@ class InferenceRunnerBase(Callback):
     def _trigger(self):
         for inf in self.infs:
-            inf.before_inference()
+            inf.before_epoch()
 
         # iterate over the data, and run the hooked session
         self._input_source.reset_state()
@@ -122,8 +111,8 @@ class InferenceRunnerBase(Callback):
         except StopIteration:
             raise RuntimeError(
                 "[InferenceRunner] input stopped before reaching its size()! " + msg)
-        summary_inferencer(self.trainer, self.infs)
+        for inf in self.infs:
+            inf.trigger_epoch()
 
 
 class InferenceRunner(InferenceRunnerBase):
@@ -193,6 +182,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._hooks = [self._build_hook(inf) for inf in self.infs]
         self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
+
+        for inf in self.infs:
+            inf.setup_graph(self.trainer)
 
     class InferencerToHookDataParallel(InferencerToHook):
         def __init__(self, inf, fetches, size):
             """
@@ -226,7 +218,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     def _trigger(self):
         for inf in self.infs:
-            inf.before_inference()
+            inf.before_epoch()
 
         self._input_source.reset_state()
         total = self._size
@@ -248,4 +240,5 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             logger.error(
                 "[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
             logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
-        summary_inferencer(self.trainer, self.infs)
+        for inf in self.infs:
+            inf.trigger_epoch()
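The runner-side calls above (inf.setup_graph(...), inf.before_epoch(), inf.trigger_epoch()) are the public Callback entry points; the assumption, consistent with the tensorpack callback convention, is that each forwards to the corresponding underscore hook that Inferencer now implements. A simplified, hypothetical sketch of that dispatch pattern (not the actual Callback implementation):

# Simplified illustration of the wrapper-to-hook dispatch assumed above;
# the real tensorpack Callback base class carries more state and checks.
class CallbackSketch(object):
    def setup_graph(self, trainer):
        self.trainer = trainer        # make the trainer (and its monitors) available
        self._setup_graph()

    def before_epoch(self):
        self._before_epoch()          # Inferencer._before_epoch -> _before_inference()

    def trigger_epoch(self):
        self._trigger_epoch()         # Inferencer._trigger_epoch -> _after_inference() -> monitors

    # subclasses override the underscore hooks
    def _setup_graph(self):
        pass

    def _before_epoch(self):
        pass

    def _trigger_epoch(self):
        pass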