Commit 194cda0b authored by Yuxin Wu's avatar Yuxin Wu

remove get_stat; more general evaluator

parent b2ec42a8
......@@ -168,7 +168,7 @@ def get_config():
HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
RunOp(lambda: M.update_target_param()),
dataset_train,
PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 2),
]),
# save memory for multiprocess evaluator
session_config=get_default_sess_config(0.6),
......
......@@ -156,13 +156,6 @@ class AtariPlayer(RLEnvironment):
isOver = isOver or newlives < oldlives
return (r, isOver)
def get_stat(self):
    """Summarize all episodes recorded since the last ``reset_stat()``.

    :returns: dict with keys ``avg_score`` (mean episode score) and
        ``max_score``; an empty dict when no episode has finished yet.
    """
    scores = self.stats['score']
    # np.max raises ValueError on an empty sequence, and np.mean would
    # first emit a RuntimeWarning and produce nan — guard explicitly
    # instead of catching the exception after the warning.
    if not scores:
        return {}
    return {'avg_score': np.mean(scores),
            'max_score': float(np.max(scores))}
if __name__ == '__main__':
import sys
import time
......
......@@ -28,7 +28,7 @@ class RLEnvironment(object):
def action(self, act):
"""
Perform an action. Will automatically start a new episode if isOver==True
:params act: the action
:param act: the action
:returns: (reward, isOver)
"""
......@@ -40,19 +40,13 @@ class RLEnvironment(object):
""" return an `ActionSpace` instance"""
raise NotImplementedError()
def get_stat(self):
    """
    Return a dict of statistics (e.g., score) for all the episodes since
    the last call to reset_stat. The base implementation records nothing,
    so it reports an empty dict; subclasses may override.
    """
    return {}
def reset_stat(self):
    """Reset all statistics counters to a fresh, empty state.

    After this call ``self.stats`` maps each statistic name to a new
    (initially empty) list.
    """
    self.stats = defaultdict(list)
def play_one_episode(self, func, stat='score'):
""" play one episode for eval.
:params func: call with the state and return an action
:param func: call with the state and return an action
:returns: the score of this episode
"""
while True:
......@@ -102,9 +96,6 @@ class ProxyPlayer(RLEnvironment):
def __init__(self, player):
    """
    :param player: the underlying player instance to wrap; all calls are
        delegated to it.
    """
    self.player = player
def get_stat(self):
    """Fetch the statistics dict from the wrapped player."""
    stats = self.player.get_stat()
    return stats
def reset_stat(self):
    """Clear the statistics counters of the wrapped player."""
    self.player.reset_stat()
......
......@@ -170,10 +170,14 @@ class ExpReplay(DataFlow, Callback):
self.exploration -= self.exploration_epoch_anneal
logger.info("Exploration changed to {}".format(self.exploration))
# log player statistics
stats = self.player.get_stat()
stats = self.player.stats
for k, v in six.iteritems(stats):
if isinstance(v, float):
self.trainer.write_scalar_summary('expreplay/' + k, v)
try:
mean, max = np.mean(v), np.max(v)
self.trainer.write_scalar_summary('expreplay/mean_' + k, mean)
self.trainer.write_scalar_summary('expreplay/max_' + k, max)
except:
pass
self.player.reset_stat()
if __name__ == '__main__':
......
......@@ -10,9 +10,11 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']
class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images"""
""" Include history frames in state, or use black images
Assume player will do auto-restart.
"""
def __init__(self, player, hist_len):
""" :params hist_len: total length of the state, including the current
""" :param hist_len: total length of the state, including the current
and `hist_len-1` history"""
super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len)
......
......@@ -47,11 +47,11 @@ class StatHolder(object):
"""
self.print_tag = None if print_tag is None else set(print_tag)
def get_stat_now(self, key):
    """
    Return the value of a stat in the current epoch.

    :param key: name of the statistic to look up.
    :returns: the recorded value; raises KeyError when the statistic has
        not been set in the current epoch.
    """
    # NOTE(review): this span of the diff interleaved the pre- and
    # post-change versions of the method; only the post-change version
    # (parameter renamed k -> key) is kept here.
    return self.stat_now[key]
def finalize(self):
"""
......
......@@ -153,8 +153,8 @@ class MultiThreadAsyncPredictor(object):
def put_task(self, inputs, callback=None):
"""
:params inputs: a data point (list of component) matching input_names (not batched)
:params callback: a callback to get called with the list of outputs
:param inputs: a data point (list of component) matching input_names (not batched)
:param callback: a callback to get called with the list of outputs
:returns: a Future of output."""
f = Future()
if callback is not None:
......
......@@ -44,10 +44,8 @@ def logSoftmax(x):
:param x: NxC tensor.
:returns: NxC tensor.
"""
with tf.op_scope([x], 'logSoftmax'):
z = x - tf.reduce_max(x, 1, keep_dims=True)
logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
return logprob
logger.warn("symbf.logSoftmax is deprecated in favor of tf.nn.log_softmax")
return tf.nn.log_softmax(x)
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
"""
......@@ -73,11 +71,13 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
cost = tf.reduce_mean(cost, name=name)
return cost
def print_stat(x, message=None):
    """ a simple print op.
    Use it like: x = print_stat(x)

    :param x: the tensor to inspect; returned unchanged with a Print
        side effect attached.
    :param message: prefix for the printed output; defaults to the op
        name of ``x`` so the output is identifiable.
    """
    # NOTE(review): this span of the diff interleaved the pre- and
    # post-change versions; only the post-change version (with the
    # optional `message` parameter) is kept here.
    if message is None:
        message = x.op.name
    return tf.Print(x, [tf.reduce_mean(x), x], summarize=20, message=message)
def rms(x, name=None):
if name is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment