Commit 90a74f02 authored by Yuxin Wu

stats for RLEnv

parent 48115f68
@@ -255,8 +255,8 @@ def get_config(romfile):
output = output.strip()
output = output[output.find(']')+1:]
mean, maximum = re.findall('[0-9\.]+', output)
-self.trainer.write_scalar_summary('mean_score', mean)
-self.trainer.write_scalar_summary('max_score', maximum)
+self.trainer.write_scalar_summary('eval_mean_score', mean)
+self.trainer.write_scalar_summary('eval_max_score', maximum)
return TrainConfig(
dataset=dataset_train,
......
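For context on the renamed summaries: the surrounding lines parse the evaluation subprocess output with a regex and report the mean/max scores. A minimal sketch of that parsing, assuming an output format like `[eval] mean: 12.5, max: 31.0` (the real format is not shown in this diff):

```python
import re

def parse_eval_scores(output):
    # drop everything up to and including the first ']', then grab the two numbers
    output = output.strip()
    output = output[output.find(']') + 1:]
    mean, maximum = re.findall('[0-9.]+', output)
    return float(mean), float(maximum)

mean, maximum = parse_eval_scores("[eval] mean: 12.5, max: 31.0")
# self.trainer.write_scalar_summary('eval_mean_score', mean)
# self.trainer.write_scalar_summary('eval_max_score', maximum)
```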
@@ -6,9 +6,10 @@
from abc import abstractmethod, ABCMeta
import random
import numpy as np
-from collections import deque, namedtuple
+from collections import deque, namedtuple, defaultdict
from tqdm import tqdm
import cv2
import six
from .base import DataFlow
from tensorpack.utils import *
@@ -26,6 +27,9 @@ Experience = namedtuple('Experience',
class RLEnvironment(object):
__metaclass__ = ABCMeta
def __init__(self):
self.reset_stat()
@abstractmethod
def current_state(self):
"""
@@ -40,6 +44,16 @@ class RLEnvironment(object):
:returns: (reward, isOver)
"""
@abstractmethod
def get_stat(self):
"""
return a dict of statistics (e.g., score) after running for a while
"""
def reset_stat(self):
""" reset the statistics counter"""
self.stats = defaultdict(list)
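A minimal sketch of the statistics contract introduced here: the base class owns `reset_stat()` and the `self.stats` defaultdict, while a subclass appends raw values during play and summarizes them in `get_stat()`. The subclass and its `episode_length` key below are made up for illustration.

```python
from collections import defaultdict

import numpy as np

class StatEnvBase(object):
    def __init__(self):
        self.reset_stat()

    def get_stat(self):
        """ return a dict of statistics (e.g., score) after running for a while """
        return {}

    def reset_stat(self):
        """ reset the statistics counter """
        self.stats = defaultdict(list)

class CountingEnv(StatEnvBase):
    def finish_episode(self, n_frames):
        self.stats['episode_length'].append(n_frames)

    def get_stat(self):
        if not self.stats['episode_length']:
            return {}
        return {'avg_length': float(np.mean(self.stats['episode_length']))}

env = CountingEnv()
env.finish_episode(120)
env.finish_episode(80)
print(env.get_stat())   # {'avg_length': 100.0}
env.reset_stat()        # cleared between reporting periods
print(env.get_stat())   # {}
```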
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
def __init__(self):
@@ -67,7 +81,9 @@ class ExpReplay(DataFlow, Callback):
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
-reward_clip=None):
+reward_clip=None,
+new_experience_per_step=1
+):
"""
:param predictor: callable. called with a state, returns a distribution
:param player: a `RLEnvironment`
@@ -117,6 +133,7 @@ class ExpReplay(DataFlow, Callback):
idxs = self.rng.randint(len(self.mem), size=self.batch_size)
batch_exp = [self.mem[k] for k in idxs]
yield self._process_batch(batch_exp)
for _ in range(self.new_experience_per_step):
self._populate_exp()
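A toy sketch of the sampling loop above, assuming a deque-backed replay memory; `TinyReplay`, its stub `_populate_exp()`, and the integer "experiences" are stand-ins for the real ExpReplay internals, and only `new_experience_per_step` comes from this commit.

```python
from collections import deque

import numpy as np

class TinyReplay(object):
    def __init__(self, batch_size=4, new_experience_per_step=1):
        self.mem = deque(maxlen=1000)
        self.batch_size = batch_size
        self.new_experience_per_step = new_experience_per_step
        self.rng = np.random.RandomState(0)

    def _populate_exp(self):
        # stand-in for playing one step and storing (state, action, reward, isOver)
        self.mem.append(len(self.mem))

    def get_data(self):
        while True:
            idxs = self.rng.randint(len(self.mem), size=self.batch_size)
            yield [self.mem[k] for k in idxs]
            # after every yielded batch, mix in a few fresh experiences
            for _ in range(self.new_experience_per_step):
                self._populate_exp()

replay = TinyReplay(new_experience_per_step=2)
for _ in range(10):
    replay._populate_exp()             # warm up the memory
gen = replay.get_data()
print(next(gen), len(replay.mem))      # 10 experiences when the first batch is drawn
print(next(gen), len(replay.mem))      # 12 after the generator advances one step
```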
def _process_batch(self, batch_exp):
@@ -144,7 +161,11 @@ class ExpReplay(DataFlow, Callback):
if self.exploration > self.end_exploration:
self.exploration -= self.exploration_epoch_anneal
logger.info("Exploration changed to {}".format(self.exploration))
stats = self.player.get_stat()
for k, v in six.iteritems(stats):
if isinstance(v, float):
self.trainer.write_scalar_summary('expreplay/' + k, v)
self.player.reset_stat()
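Note on the `isinstance(v, float)` filter above: only scalar float statistics become summaries, so ints and raw lists are skipped; on common NumPy builds `np.float64` subclasses Python's `float`, which is why values coming from `np.mean` pass the check. A tiny illustration (the stats dict is made up):

```python
import numpy as np

stats = {'avg_score': np.mean([10.0, 15.0]),   # np.float64, a subclass of float
         'max_score': 15.0,
         'n_episodes': 2}                      # int: filtered out
for k, v in sorted(stats.items()):
    if isinstance(v, float):
        print('expreplay/' + k, v)             # stands in for trainer.write_scalar_summary
```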
if __name__ == '__main__':
......
@@ -107,17 +107,20 @@ class AtariPlayer(RLEnvironment):
:param action_repeat: repeat each action `action_repeat` times and skip those frames
:param image_shape: the shape of the observed image
"""
super(AtariPlayer, self).__init__()
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.last_act = 0
self.frames = deque(maxlen=hist_len)
self.current_accum_score = 0
self.restart()
def restart(self):
"""
Restart the game and populate frames with the beginning frame
"""
self.current_accum_score = 0
self.frames.clear()
s = self.driver.grab_image()
@@ -156,11 +159,22 @@ class AtariPlayer(RLEnvironment):
if isOver:
break
s = cv2.resize(s, self.image_shape)
self.current_accum_score += totr
self.frames.append(s)
if isOver:
self.stats['score'].append(self.current_accum_score)
self.restart()
return (totr, isOver)
def get_stat(self):
try:
print(self.stats)   # debug print of the accumulated per-episode stats
return {'avg_score': np.mean(self.stats['score']),
'max_score': float(np.max(self.stats['score']))
}
except ValueError:
# np.max raises ValueError when no episode has finished yet
return {}
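A compressed sketch of the episode bookkeeping added to AtariPlayer: per-step rewards accumulate into `current_accum_score`, the total is pushed into `stats['score']` when the game ends, and `get_stat()` summarizes it. `ToyPlayer` and its hand-fed rewards are illustrative; only the bookkeeping pattern mirrors the diff.

```python
from collections import defaultdict

import numpy as np

class ToyPlayer(object):
    def __init__(self):
        self.stats = defaultdict(list)
        self.restart()

    def restart(self):
        self.current_accum_score = 0

    def action(self, reward, is_over):
        # in the real player, reward/is_over come from the emulator step
        self.current_accum_score += reward
        if is_over:
            self.stats['score'].append(self.current_accum_score)
            self.restart()
        return reward, is_over

    def get_stat(self):
        if not self.stats['score']:
            return {}
        return {'avg_score': float(np.mean(self.stats['score'])),
                'max_score': float(np.max(self.stats['score']))}

p = ToyPlayer()
for r, over in [(1, False), (2, True), (5, True)]:
    p.action(r, over)
print(p.get_stat())   # {'avg_score': 4.0, 'max_score': 5.0}
```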
if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True)
num = a.get_num_actions()
......