Commit d3802e79 authored by Yuxin Wu's avatar Yuxin Wu

some internal rename

parent 01f54c26
......@@ -46,8 +46,7 @@ class Inferencer(object):
def after_inference(self):
"""
Called after a round of inference ends.
Returns a dict of statistics which will be logged by the :class:`InferenceRunner`.
The inferencer needs to handle other type of logging by itself, if there is any.
Returns a dict of scalar statistics which will be logged to monitors.
"""
return self._after_inference()
......@@ -72,17 +71,17 @@ class ScalarStats(Inferencer):
The value will be averaged over all given datapoints.
"""
def __init__(self, names_to_print, prefix='validation'):
def __init__(self, names, prefix='validation'):
"""
Args:
names_to_print(list or str): list of names or just one name. The
names(list or str): list of names or just one name. The
corresponding tensors have to be scalar.
prefix(str): a prefix for logging
"""
if not isinstance(names_to_print, list):
self.names = [names_to_print]
if not isinstance(names, list):
self.names = [names]
else:
self.names = names_to_print
self.names = names
self.prefix = prefix
def _get_output_tensors(self):
......@@ -128,7 +127,7 @@ class ClassificationError(Inferencer):
wrong_tensor_name(str): name of the ``wrong`` tensor.
The default is the same as the default output name of
:meth:`prediction_incorrect`.
summary_name(str): the name for logging.
summary_name(str): the name to log the error with.
"""
self.wrong_tensor_name = wrong_tensor_name
self.summary_name = summary_name
......@@ -161,7 +160,7 @@ class BinaryClassificationStats(Inferencer):
prediction vector and the label vector.
"""
def __init__(self, pred_tensor_name, label_tensor_name, summary_prefix='val'):
def __init__(self, pred_tensor_name, label_tensor_name, prefix='val'):
"""
Args:
pred_tensor_name(str): name of the 0/1 prediction tensor.
......@@ -169,7 +168,7 @@ class BinaryClassificationStats(Inferencer):
"""
self.pred_tensor_name = pred_tensor_name
self.label_tensor_name = label_tensor_name
self.prefix = summary_prefix
self.prefix = prefix
def _get_output_tensors(self):
return [self.pred_tensor_name, self.label_tensor_name]
......
......@@ -57,6 +57,8 @@ class OutputTensorDispatcher(object):
def summary_inferencer(trainer, infs):
for inf in infs:
ret = inf.after_inference()
if ret is None:
continue
for k, v in six.iteritems(ret):
try:
v = float(v)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment