Commit 375123f5 authored by Yuxin Wu's avatar Yuxin Wu

refactor inferencer

parent 6728b686
......@@ -92,7 +92,7 @@ class Model(ModelDesc):
def _build_graph(self, inputs, is_training):
state, action, reward, next_state, isOver = inputs
self.predict_value = self._get_DQN_prediction(state, is_training)
action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
action_onehot = tf.one_hot(action, NUM_ACTIONS)
pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) #N,
max_pred_reward = tf.reduce_mean(tf.reduce_max(
self.predict_value, 1), name='predict_reward')
......
......@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
import six
from six.moves import zip, map
from ..dataflow import DataFlow
......@@ -43,8 +44,9 @@ class Inferencer(object):
def after_inference(self):
"""
Called after a round of inference ends.
Returns a dict of statistics.
"""
self._after_inference()
return self._after_inference()
def _after_inference(self):
pass
......@@ -84,8 +86,6 @@ class InferenceRunner(Callback):
input_names = [x.name for x in self.input_vars]
self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors)
for v in self.vcs:
v.trainer = self.trainer
def _find_output_tensors(self):
self.output_tensors = [] # list of names
......@@ -118,7 +118,14 @@ class InferenceRunner(Callback):
pbar.update()
for vc in self.vcs:
vc.after_inference()
ret = vc.after_inference()
for k, v in six.iteritems(ret):
try:
v = float(v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(vc).__name__))
continue
self.trainer.write_scalar_summary(k, v)
class ScalarStats(Inferencer):
"""
......@@ -150,10 +157,12 @@ class ScalarStats(Inferencer):
self.stats = np.mean(self.stats, axis=0)
assert len(self.stats) == len(self.names)
ret = {}
for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
self.trainer.write_scalar_summary(name, stat)
ret[name] = stat
return ret
class ClassificationError(Inferencer):
"""
......@@ -187,7 +196,7 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size)
def _after_inference(self):
self.trainer.write_scalar_summary(self.summary_name, self.err_stat.ratio)
return {self.summary_name: self.err_stat.ratio}
class BinaryClassificationStats(Inferencer):
""" Compute precision/recall in binary classification, given the
......@@ -214,5 +223,5 @@ class BinaryClassificationStats(Inferencer):
self.stat.feed(pred, label)
def _after_inference(self):
self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision)
self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall)
return {self.prefix + '_precision': self.stat.precision,
self.prefix + '_recall': self.stat.recall}
......@@ -6,15 +6,6 @@ import tensorflow as tf
import numpy as np
from ..utils import logger
def one_hot(y, num_labels):
"""
:param y: prediction. an Nx1 int tensor.
:param num_labels: an int. number of output classes
:returns: an NxC onehot matrix.
"""
logger.warn("symbf.one_hot is deprecated in favor of more general tf.one_hot")
return tf.one_hot(y, num_labels, 1.0, 0.0, name='one_hot')
def prediction_incorrect(logits, label, topk=1):
"""
:param logits: NxC
......
......@@ -10,6 +10,7 @@ from datetime import datetime
from six.moves import input
import sys
from .utils import memoized
from .fs import mkdir_p
__all__ = []
......
......@@ -11,8 +11,6 @@ import collections
import numpy as np
import six
from . import logger
__all__ = ['change_env',
'map_arg',
'get_rng', 'memoized',
......@@ -50,28 +48,39 @@ class memoized(object):
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def map_arg(**maps):
"""
......@@ -96,6 +105,7 @@ def get_rng(obj=None):
return np.random.RandomState(seed)
def get_dataset_path(*args):
from . import logger
d = os.environ.get('TENSORPACK_DATASET', None)
if d is None:
d = os.path.abspath(os.path.join(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment