Commit d3802e79 authored by Yuxin Wu's avatar Yuxin Wu

some internal rename

parent 01f54c26
...@@ -46,8 +46,7 @@ class Inferencer(object): ...@@ -46,8 +46,7 @@ class Inferencer(object):
def after_inference(self): def after_inference(self):
""" """
Called after a round of inference ends. Called after a round of inference ends.
Returns a dict of statistics which will be logged by the :class:`InferenceRunner`. Returns a dict of scalar statistics which will be logged to monitors.
The inferencer needs to handle other type of logging by itself, if there is any.
""" """
return self._after_inference() return self._after_inference()
...@@ -72,17 +71,17 @@ class ScalarStats(Inferencer): ...@@ -72,17 +71,17 @@ class ScalarStats(Inferencer):
The value will be averaged over all given datapoints. The value will be averaged over all given datapoints.
""" """
def __init__(self, names_to_print, prefix='validation'): def __init__(self, names, prefix='validation'):
""" """
Args: Args:
names_to_print(list or str): list of names or just one name. The names(list or str): list of names or just one name. The
corresponding tensors have to be scalar. corresponding tensors have to be scalar.
prefix(str): a prefix for logging prefix(str): a prefix for logging
""" """
if not isinstance(names_to_print, list): if not isinstance(names, list):
self.names = [names_to_print] self.names = [names]
else: else:
self.names = names_to_print self.names = names
self.prefix = prefix self.prefix = prefix
def _get_output_tensors(self): def _get_output_tensors(self):
...@@ -128,7 +127,7 @@ class ClassificationError(Inferencer): ...@@ -128,7 +127,7 @@ class ClassificationError(Inferencer):
wrong_tensor_name(str): name of the ``wrong`` tensor. wrong_tensor_name(str): name of the ``wrong`` tensor.
The default is the same as the default output name of The default is the same as the default output name of
:meth:`prediction_incorrect`. :meth:`prediction_incorrect`.
summary_name(str): the name for logging. summary_name(str): the name to log the error with.
""" """
self.wrong_tensor_name = wrong_tensor_name self.wrong_tensor_name = wrong_tensor_name
self.summary_name = summary_name self.summary_name = summary_name
...@@ -161,7 +160,7 @@ class BinaryClassificationStats(Inferencer): ...@@ -161,7 +160,7 @@ class BinaryClassificationStats(Inferencer):
prediction vector and the label vector. prediction vector and the label vector.
""" """
def __init__(self, pred_tensor_name, label_tensor_name, summary_prefix='val'): def __init__(self, pred_tensor_name, label_tensor_name, prefix='val'):
""" """
Args: Args:
pred_tensor_name(str): name of the 0/1 prediction tensor. pred_tensor_name(str): name of the 0/1 prediction tensor.
...@@ -169,7 +168,7 @@ class BinaryClassificationStats(Inferencer): ...@@ -169,7 +168,7 @@ class BinaryClassificationStats(Inferencer):
""" """
self.pred_tensor_name = pred_tensor_name self.pred_tensor_name = pred_tensor_name
self.label_tensor_name = label_tensor_name self.label_tensor_name = label_tensor_name
self.prefix = summary_prefix self.prefix = prefix
def _get_output_tensors(self): def _get_output_tensors(self):
return [self.pred_tensor_name, self.label_tensor_name] return [self.pred_tensor_name, self.label_tensor_name]
......
...@@ -57,6 +57,8 @@ class OutputTensorDispatcher(object): ...@@ -57,6 +57,8 @@ class OutputTensorDispatcher(object):
def summary_inferencer(trainer, infs): def summary_inferencer(trainer, infs):
for inf in infs: for inf in infs:
ret = inf.after_inference() ret = inf.after_inference()
if ret is None:
continue
for k, v in six.iteritems(ret): for k, v in six.iteritems(ret):
try: try:
v = float(v) v = float(v)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment