Commit 194cda0b authored by Yuxin Wu

remove get_stat; more general evaluator

parent b2ec42a8
@@ -168,7 +168,7 @@ def get_config():
             HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
             RunOp(lambda: M.update_target_param()),
             dataset_train,
-            PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 2),
         ]),
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.6),
...
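The Evaluator callback in this example now takes explicit input and output tensor-name lists (['state'], ['fct/output']) instead of a single hard-coded output tensor, which is what makes it "more general": the same evaluation code can be pointed at any model by naming its placeholders and outputs. A rough sketch of such an evaluation loop, assuming a predictor callable built from those name lists and the play_one_episode interface shown further down; the helper names here are illustrative and not part of this commit:

import numpy as np

def eval_model(player, predictor, nr_episode):
    # `predictor` feeds the tensors named by the input list (here just 'state')
    # and fetches the tensors named by the output list (here 'fct/output').
    scores = []
    for _ in range(nr_episode):
        def greedy_action(state):
            q_values = predictor([state])[0]   # fetched Q-values for one state
            return np.argmax(q_values)
        scores.append(player.play_one_episode(greedy_action))
    return np.mean(scores), np.max(scores)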
@@ -156,13 +156,6 @@ class AtariPlayer(RLEnvironment):
         isOver = isOver or newlives < oldlives
         return (r, isOver)
 
-    def get_stat(self):
-        try:
-            return {'avg_score': np.mean(self.stats['score']),
-                    'max_score': float(np.max(self.stats['score'])) }
-        except ValueError:
-            return {}
-
 if __name__ == '__main__':
     import sys
     import time
...
@@ -28,7 +28,7 @@ class RLEnvironment(object):
     def action(self, act):
         """
         Perform an action. Will automatically start a new episode if isOver==True
-        :params act: the action
+        :param act: the action
         :returns: (reward, isOver)
         """
@@ -40,19 +40,13 @@ class RLEnvironment(object):
         """ return an `ActionSpace` instance"""
         raise NotImplementedError()
 
-    def get_stat(self):
-        """
-        return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
-        """
-        return {}
-
     def reset_stat(self):
         """ reset all statistics counter"""
         self.stats = defaultdict(list)
 
     def play_one_episode(self, func, stat='score'):
         """ play one episode for eval.
-        :params func: call with the state and return an action
+        :param func: call with the state and return an action
         :returns: the score of this episode
         """
         while True:
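For context, play_one_episode only needs a policy callable: it repeatedly passes the current state to func, applies the returned action, and accumulates reward until the environment reports isOver. A self-contained toy illustration of that contract (the environment here is made up; only the interface mirrors RLEnvironment):

import random

class ToyEnv(object):
    """A fake 10-step environment that auto-restarts, as the docstring requires."""
    def __init__(self):
        self.steps = 0
    def current_state(self):
        return self.steps
    def action(self, act):
        self.steps += 1
        reward = 1 if act == 1 else 0
        is_over = self.steps >= 10
        if is_over:
            self.steps = 0          # start a new episode automatically
        return reward, is_over

def play_one_episode(env, func):
    score = 0
    while True:
        reward, is_over = env.action(func(env.current_state()))
        score += reward
        if is_over:
            return score

print(play_one_episode(ToyEnv(), lambda s: random.choice([0, 1])))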
@@ -102,9 +96,6 @@ class ProxyPlayer(RLEnvironment):
     def __init__(self, player):
         self.player = player
 
-    def get_stat(self):
-        return self.player.get_stat()
-
     def reset_stat(self):
         self.player.reset_stat()
...
@@ -170,10 +170,14 @@ class ExpReplay(DataFlow, Callback):
             self.exploration -= self.exploration_epoch_anneal
             logger.info("Exploration changed to {}".format(self.exploration))
         # log player statistics
-        stats = self.player.get_stat()
+        stats = self.player.stats
         for k, v in six.iteritems(stats):
-            if isinstance(v, float):
-                self.trainer.write_scalar_summary('expreplay/' + k, v)
+            try:
+                mean, max = np.mean(v), np.max(v)
+                self.trainer.write_scalar_summary('expreplay/mean_' + k, mean)
+                self.trainer.write_scalar_summary('expreplay/max_' + k, max)
+            except:
+                pass
         self.player.reset_stat()
 
 if __name__ == '__main__':
...
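With get_stat gone, the callback reads the raw per-episode lists in player.stats and derives the mean/max itself; the bare try/except mainly guards against an empty list, for which np.max raises ValueError (e.g. no episode finished during the epoch). The same aggregation as a standalone sketch (the helper name is illustrative):

import numpy as np

def summarize_stats(stats):
    """Turn {'score': [per-episode scores]} into {'mean_score': ..., 'max_score': ...}."""
    out = {}
    for k, v in stats.items():
        try:
            mean_v, max_v = float(np.mean(v)), float(np.max(v))
        except ValueError:
            continue    # empty list: nothing to summarize this epoch
        out['mean_' + k] = mean_v
        out['max_' + k] = max_v
    return out

print(summarize_stats({'score': [120, 340, 95]}))
# -> {'mean_score': 185.0, 'max_score': 340.0}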
@@ -10,9 +10,11 @@ from .envbase import ProxyPlayer
 __all__ = ['HistoryFramePlayer']
 
 class HistoryFramePlayer(ProxyPlayer):
-    """ Include history frames in state, or use black images"""
+    """ Include history frames in state, or use black images
+        Assume player will do auto-restart.
+    """
     def __init__(self, player, hist_len):
-        """ :params hist_len: total length of the state, including the current
+        """ :param hist_len: total length of the state, including the current
         and `hist_len-1` history"""
         super(HistoryFramePlayer, self).__init__(player)
         self.history = deque(maxlen=hist_len)
...
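The amended docstring records two assumptions: the wrapped player auto-restarts, and missing history is filled with black frames. Independent of tensorpack's actual implementation, the mechanism the class implies is a fixed-length frame buffer, roughly:

import numpy as np
from collections import deque

hist_len = 4
frame_shape = (84, 84)                      # illustrative frame size
history = deque(maxlen=hist_len)            # oldest frame drops out automatically

def stacked_state(new_frame):
    history.append(new_frame)
    # pad with black (all-zero) frames until hist_len frames have been seen
    pad = [np.zeros(frame_shape, dtype=new_frame.dtype)] * (hist_len - len(history))
    return np.stack(pad + list(history), axis=-1)

print(stacked_state(np.ones(frame_shape, dtype='uint8')).shape)   # (84, 84, 4)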
@@ -47,11 +47,11 @@ class StatHolder(object):
         """
         self.print_tag = None if print_tag is None else set(print_tag)
 
-    def get_stat_now(self, k):
+    def get_stat_now(self, key):
         """
         Return the value of a stat in the current epoch.
         """
-        return self.stat_now[k]
+        return self.stat_now[key]
 
     def finalize(self):
         """
...
@@ -153,8 +153,8 @@ class MultiThreadAsyncPredictor(object):
     def put_task(self, inputs, callback=None):
         """
-        :params inputs: a data point (list of component) matching input_names (not batched)
-        :params callback: a callback to get called with the list of outputs
+        :param inputs: a data point (list of component) matching input_names (not batched)
+        :param callback: a callback to get called with the list of outputs
         :returns: a Future of output."""
         f = Future()
         if callback is not None:
...
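As the docstring describes, put_task queues a single unbatched data point and hands back a Future; callers either block on the result or pass a callback that receives the list of outputs. A usage sketch (the predictor construction and the state/handle names are assumptions, as is a concurrent.futures-style result() on the returned Future):

# Blocking: wait for the outputs of one queued data point.
future = predictor.put_task([state])      # one data point matching input_names
outputs = future.result()                 # list of outputs, one per output name

# Non-blocking: the callback fires with the list of outputs when ready.
predictor.put_task([state], callback=lambda outputs: handle(outputs))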
@@ -44,10 +44,8 @@ def logSoftmax(x):
     :param x: NxC tensor.
     :returns: NxC tensor.
     """
-    with tf.op_scope([x], 'logSoftmax'):
-        z = x - tf.reduce_max(x, 1, keep_dims=True)
-        logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
-        return logprob
+    logger.warn("symbf.logSoftmax is deprecated in favor of tf.nn.log_softmax")
+    return tf.nn.log_softmax(x)
 
 def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
     """
@@ -73,11 +71,13 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
     cost = tf.reduce_mean(cost, name=name)
     return cost
 
-def print_stat(x):
+def print_stat(x, message=None):
     """ a simple print op.
     Use it like: x = print_stat(x)
     """
-    return tf.Print(x, [tf.reduce_mean(x), x], summarize=20)
+    if message is None:
+        message = x.op.name
+    return tf.Print(x, [tf.reduce_mean(x), x], summarize=20, message=message)
 
 def rms(x, name=None):
     if name is None:
...
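print_stat now also forwards a message prefix to tf.Print, defaulting to the instrumented tensor's op name, so output from several instrumented tensors can be told apart in the log. Usage stays as the docstring suggests (TF 1.x graph mode assumed; print_stat is the function defined above):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10], name='activations')
x = print_stat(x)                          # message defaults to 'activations'
y = print_stat(x * 2, message='scaled')    # explicit prefix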