Commit b7de25d9 authored by Yuxin Wu's avatar Yuxin Wu

make inferencer a callback

parent d44d32d6
...@@ -7,6 +7,8 @@ from abc import ABCMeta, abstractmethod ...@@ -7,6 +7,8 @@ from abc import ABCMeta, abstractmethod
import six import six
from six.moves import zip from six.moves import zip
from .base import Callback
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
...@@ -17,16 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer', ...@@ -17,16 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Inferencer(object): class Inferencer(Callback):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """ """ Base class of Inferencer. To be used with :class:`InferenceRunner`. """
def before_inference(self): def _before_epoch(self):
"""
Called before a new round of inference starts.
"""
self._before_inference() self._before_inference()
def _before_inference(self): def _before_inference(self):
"""
Called before a new round of inference starts.
"""
pass pass
def datapoint(self, output): def datapoint(self, output):
...@@ -43,14 +45,24 @@ class Inferencer(object): ...@@ -43,14 +45,24 @@ class Inferencer(object):
def _datapoint(self, output): def _datapoint(self, output):
pass pass
def after_inference(self): def _trigger_epoch(self):
ret = self._after_inference()
if ret is None:
return
for k, v in six.iteritems(ret):
try:
v = float(v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(self).__name__))
continue
else:
self.trainer.monitors.put_scalar(k, v)
def _after_inference(self):
""" """
Called after a round of inference ends. Called after a round of inference ends.
Returns a dict of scalar statistics which will be logged to monitors. Returns a dict of scalar statistics which will be logged to monitors.
""" """
return self._after_inference()
def _after_inference(self):
pass pass
def get_output_tensors(self): def get_output_tensors(self):
......
...@@ -42,20 +42,6 @@ class InferencerToHook(tf.train.SessionRunHook): ...@@ -42,20 +42,6 @@ class InferencerToHook(tf.train.SessionRunHook):
self._inf.datapoint(run_values.results) self._inf.datapoint(run_values.results)
def summary_inferencer(trainer, infs):
for inf in infs:
ret = inf.after_inference()
if ret is None:
continue
for k, v in six.iteritems(ret):
try:
v = float(v)
trainer.monitors.put_scalar(k, v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Callback): class InferenceRunnerBase(Callback):
""" Base methods for inference runner""" """ Base methods for inference runner"""
...@@ -101,6 +87,9 @@ class InferenceRunnerBase(Callback): ...@@ -101,6 +87,9 @@ class InferenceRunnerBase(Callback):
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
self._hooks.extend([CallbackToHook(cb) for cb in cbs]) self._hooks.extend([CallbackToHook(cb) for cb in cbs])
for inf in self.infs:
inf.setup_graph(self.trainer)
def _before_train(self): def _before_train(self):
self._hooks.extend(self._extra_hooks) self._hooks.extend(self._extra_hooks)
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks) self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
...@@ -111,7 +100,7 @@ class InferenceRunnerBase(Callback): ...@@ -111,7 +100,7 @@ class InferenceRunnerBase(Callback):
def _trigger(self): def _trigger(self):
for inf in self.infs: for inf in self.infs:
inf.before_inference() inf.before_epoch()
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
...@@ -122,8 +111,8 @@ class InferenceRunnerBase(Callback): ...@@ -122,8 +111,8 @@ class InferenceRunnerBase(Callback):
except StopIteration: except StopIteration:
raise RuntimeError( raise RuntimeError(
"[InferenceRunner] input stopped before reaching its size()! " + msg) "[InferenceRunner] input stopped before reaching its size()! " + msg)
for inf in self.infs:
summary_inferencer(self.trainer, self.infs) inf.trigger_epoch()
class InferenceRunner(InferenceRunnerBase): class InferenceRunner(InferenceRunnerBase):
...@@ -193,6 +182,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -193,6 +182,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs]) self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
for inf in self.infs:
inf.setup_graph(self.trainer)
class InferencerToHookDataParallel(InferencerToHook): class InferencerToHookDataParallel(InferencerToHook):
def __init__(self, inf, fetches, size): def __init__(self, inf, fetches, size):
""" """
...@@ -226,7 +218,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -226,7 +218,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _trigger(self): def _trigger(self):
for inf in self.infs: for inf in self.infs:
inf.before_inference() inf.before_epoch()
self._input_source.reset_state() self._input_source.reset_state()
total = self._size total = self._size
...@@ -248,4 +240,5 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -248,4 +240,5 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
logger.error( logger.error(
"[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!") "[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...") logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
summary_inferencer(self.trainer, self.infs) for inf in self.infs:
inf.trigger_epoch()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment