Commit 375123f5 authored by Yuxin Wu's avatar Yuxin Wu

refactor inferencer

parent 6728b686
...@@ -92,7 +92,7 @@ class Model(ModelDesc): ...@@ -92,7 +92,7 @@ class Model(ModelDesc):
def _build_graph(self, inputs, is_training): def _build_graph(self, inputs, is_training):
state, action, reward, next_state, isOver = inputs state, action, reward, next_state, isOver = inputs
self.predict_value = self._get_DQN_prediction(state, is_training) self.predict_value = self._get_DQN_prediction(state, is_training)
action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0) action_onehot = tf.one_hot(action, NUM_ACTIONS)
pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) #N, pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) #N,
max_pred_reward = tf.reduce_mean(tf.reduce_max( max_pred_reward = tf.reduce_mean(tf.reduce_max(
self.predict_value, 1), name='predict_reward') self.predict_value, 1), name='predict_reward')
......
...@@ -6,6 +6,7 @@ import tensorflow as tf ...@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import six
from six.moves import zip, map from six.moves import zip, map
from ..dataflow import DataFlow from ..dataflow import DataFlow
...@@ -43,8 +44,9 @@ class Inferencer(object): ...@@ -43,8 +44,9 @@ class Inferencer(object):
def after_inference(self): def after_inference(self):
""" """
Called after a round of inference ends. Called after a round of inference ends.
Returns a dict of statistics.
""" """
self._after_inference() return self._after_inference()
def _after_inference(self): def _after_inference(self):
pass pass
...@@ -84,8 +86,6 @@ class InferenceRunner(Callback): ...@@ -84,8 +86,6 @@ class InferenceRunner(Callback):
input_names = [x.name for x in self.input_vars] input_names = [x.name for x in self.input_vars]
self.pred_func = self.trainer.get_predict_func( self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors) input_names, self.output_tensors)
for v in self.vcs:
v.trainer = self.trainer
def _find_output_tensors(self): def _find_output_tensors(self):
self.output_tensors = [] # list of names self.output_tensors = [] # list of names
...@@ -118,7 +118,14 @@ class InferenceRunner(Callback): ...@@ -118,7 +118,14 @@ class InferenceRunner(Callback):
pbar.update() pbar.update()
for vc in self.vcs: for vc in self.vcs:
vc.after_inference() ret = vc.after_inference()
for k, v in six.iteritems(ret):
try:
v = float(v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(vc).__name__))
continue
self.trainer.write_scalar_summary(k, v)
class ScalarStats(Inferencer): class ScalarStats(Inferencer):
""" """
...@@ -150,10 +157,12 @@ class ScalarStats(Inferencer): ...@@ -150,10 +157,12 @@ class ScalarStats(Inferencer):
self.stats = np.mean(self.stats, axis=0) self.stats = np.mean(self.stats, axis=0)
assert len(self.stats) == len(self.names) assert len(self.stats) == len(self.names)
ret = {}
for stat, name in zip(self.stats, self.names): for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name) opname, _ = get_op_var_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
self.trainer.write_scalar_summary(name, stat) ret[name] = stat
return ret
class ClassificationError(Inferencer): class ClassificationError(Inferencer):
""" """
...@@ -187,7 +196,7 @@ class ClassificationError(Inferencer): ...@@ -187,7 +196,7 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size) self.err_stat.feed(wrong, batch_size)
def _after_inference(self): def _after_inference(self):
self.trainer.write_scalar_summary(self.summary_name, self.err_stat.ratio) return {self.summary_name: self.err_stat.ratio}
class BinaryClassificationStats(Inferencer): class BinaryClassificationStats(Inferencer):
""" Compute precision/recall in binary classification, given the """ Compute precision/recall in binary classification, given the
...@@ -214,5 +223,5 @@ class BinaryClassificationStats(Inferencer): ...@@ -214,5 +223,5 @@ class BinaryClassificationStats(Inferencer):
self.stat.feed(pred, label) self.stat.feed(pred, label)
def _after_inference(self): def _after_inference(self):
self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision) return {self.prefix + '_precision': self.stat.precision,
self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall) self.prefix + '_recall': self.stat.recall}
...@@ -6,15 +6,6 @@ import tensorflow as tf ...@@ -6,15 +6,6 @@ import tensorflow as tf
import numpy as np import numpy as np
from ..utils import logger from ..utils import logger
def one_hot(y, num_labels):
"""
:param y: prediction. an Nx1 int tensor.
:param num_labels: an int. number of output classes
:returns: an NxC onehot matrix.
"""
logger.warn("symbf.one_hot is deprecated in favor of more general tf.one_hot")
return tf.one_hot(y, num_labels, 1.0, 0.0, name='one_hot')
def prediction_incorrect(logits, label, topk=1): def prediction_incorrect(logits, label, topk=1):
""" """
:param logits: NxC :param logits: NxC
......
...@@ -10,6 +10,7 @@ from datetime import datetime ...@@ -10,6 +10,7 @@ from datetime import datetime
from six.moves import input from six.moves import input
import sys import sys
from .utils import memoized
from .fs import mkdir_p from .fs import mkdir_p
__all__ = [] __all__ = []
......
...@@ -11,8 +11,6 @@ import collections ...@@ -11,8 +11,6 @@ import collections
import numpy as np import numpy as np
import six import six
from . import logger
__all__ = ['change_env', __all__ = ['change_env',
'map_arg', 'map_arg',
'get_rng', 'memoized', 'get_rng', 'memoized',
...@@ -50,28 +48,39 @@ class memoized(object): ...@@ -50,28 +48,39 @@ class memoized(object):
(not reevaluated). (not reevaluated).
''' '''
def __init__(self, func): def __init__(self, func):
self.func = func self.func = func
self.cache = {} self.cache = {}
def __call__(self, *args): def __call__(self, *args):
if not isinstance(args, collections.Hashable): if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance. # uncacheable. a list, for instance.
# better to not cache than blow up. # better to not cache than blow up.
return self.func(*args) return self.func(*args)
if args in self.cache: if args in self.cache:
return self.cache[args] return self.cache[args]
else: else:
value = self.func(*args) value = self.func(*args)
self.cache[args] = value self.cache[args] = value
return value return value
def __repr__(self): def __repr__(self):
'''Return the function's docstring.''' '''Return the function's docstring.'''
return self.func.__doc__ return self.func.__doc__
def __get__(self, obj, objtype): def __get__(self, obj, objtype):
'''Support instance methods.''' '''Support instance methods.'''
return functools.partial(self.__call__, obj) return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def map_arg(**maps): def map_arg(**maps):
""" """
...@@ -96,6 +105,7 @@ def get_rng(obj=None): ...@@ -96,6 +105,7 @@ def get_rng(obj=None):
return np.random.RandomState(seed) return np.random.RandomState(seed)
def get_dataset_path(*args): def get_dataset_path(*args):
from . import logger
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d is None: if d is None:
d = os.path.abspath(os.path.join( d = os.path.abspath(os.path.join(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment