Commit 90a74f02 authored by Yuxin Wu

stats for RLEnv

parent 48115f68
@@ -255,8 +255,8 @@ def get_config(romfile):
            output = output.strip()
            output = output[output.find(']')+1:]
            mean, maximum = re.findall('[0-9\.]+', output)
-           self.trainer.write_scalar_summary('mean_score', mean)
-           self.trainer.write_scalar_summary('max_score', maximum)
+           self.trainer.write_scalar_summary('eval_mean_score', mean)
+           self.trainer.write_scalar_summary('eval_max_score', maximum)
    return TrainConfig(
        dataset=dataset_train,
...
@@ -6,9 +6,10 @@
from abc import abstractmethod, ABCMeta
import random
import numpy as np
-from collections import deque, namedtuple
+from collections import deque, namedtuple, defaultdict
from tqdm import tqdm
import cv2
+import six
from .base import DataFlow
from tensorpack.utils import *
@@ -26,6 +27,9 @@ Experience = namedtuple('Experience',
class RLEnvironment(object):
    __meta__ = ABCMeta
+   def __init__(self):
+       self.reset_stat()
    @abstractmethod
    def current_state(self):
        """
@@ -40,6 +44,16 @@ class RLEnvironment(object):
        :returns: (reward, isOver)
        """
+   @abstractmethod
+   def get_stat(self):
+       """
+       return a dict of statistics (e.g., score) after running for a while
+       """
+   def reset_stat(self):
+       """ reset the statistics counter"""
+       self.stats = defaultdict(list)
class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""
    def __init__(self):
@@ -67,7 +81,9 @@ class ExpReplay(DataFlow, Callback):
            exploration=1,
            end_exploration=0.1,
            exploration_epoch_anneal=0.002,
-           reward_clip=None):
+           reward_clip=None,
+           new_experience_per_step=1
+           ):
        """
        :param predictor: callable. called with a state, return a distribution
        :param player: a `RLEnvironment`
@@ -117,7 +133,8 @@ class ExpReplay(DataFlow, Callback):
            idxs = self.rng.randint(len(self.mem), size=self.batch_size)
            batch_exp = [self.mem[k] for k in idxs]
            yield self._process_batch(batch_exp)
-           self._populate_exp()
+           for _ in range(self.new_experience_per_step):
+               self._populate_exp()
    def _process_batch(self, batch_exp):
        state_shape = batch_exp[0].state.shape
@@ -144,7 +161,11 @@ class ExpReplay(DataFlow, Callback):
        if self.exploration > self.end_exploration:
            self.exploration -= self.exploration_epoch_anneal
            logger.info("Exploration changed to {}".format(self.exploration))
+       stats = self.player.get_stat()
+       for k, v in six.iteritems(stats):
+           if isinstance(v, float):
+               self.trainer.write_scalar_summary('expreplay/' + k, v)
+       self.player.reset_stat()
if __name__ == '__main__':
...
@@ -107,17 +107,20 @@ class AtariPlayer(RLEnvironment):
        :param action_repeat: repeat each action `action_repeat` times and skip those frames
        :param image_shape: the shape of the observed image
        """
+       super(AtariPlayer, self).__init__()
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.last_act = 0
        self.frames = deque(maxlen=hist_len)
+       self.current_accum_score = 0
        self.restart()
    def restart(self):
        """
        Restart the game and populate frames with the beginning frame
        """
+       self.current_accum_score = 0
        self.frames.clear()
        s = self.driver.grab_image()
@@ -156,11 +159,22 @@ class AtariPlayer(RLEnvironment):
            if isOver:
                break
        s = cv2.resize(s, self.image_shape)
+       self.current_accum_score += totr
        self.frames.append(s)
        if isOver:
+           self.stats['score'].append(self.current_accum_score)
            self.restart()
        return (totr, isOver)
+   def get_stat(self):
+       try:
+           print self.stats
+           return {'avg_score': np.mean(self.stats['score']),
+                   'max_score': float(np.max(self.stats['score']))
+                   }
+       except ValueError:
+           return {}
if __name__ == '__main__':
    a = AtariDriver('breakout.bin', viz=True)
    num = a.get_num_actions()
...
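Beyond the diff itself, the following is a minimal sketch of how the statistics protocol introduced by this commit is meant to be used; it is not part of the commit. It assumes only what the diff shows: `reset_stat()` initializes `self.stats` as a `defaultdict(list)`, a player appends each finished episode's score to `self.stats['score']`, `get_stat()` summarizes those scores, and `ExpReplay` writes any float values as `expreplay/<key>` scalar summaries at the end of every epoch before calling `reset_stat()`. The `ToyPlayer` class and its dummy reward logic below are hypothetical.

# Hypothetical sketch (not from this commit): a toy RLEnvironment-style player
# implementing the new stats hooks. Only self.stats, get_stat() and reset_stat()
# come from the diff; ToyPlayer and its dummy reward logic are made up.
import numpy as np
from collections import defaultdict

class ToyPlayer(object):              # would subclass RLEnvironment in tensorpack
    def __init__(self):
        self.reset_stat()             # the base-class __init__ added above does this
        self.current_accum_score = 0

    def action(self, act):
        reward = 1.0                  # dummy reward, independent of act
        self.current_accum_score += reward
        isOver = self.current_accum_score >= 10
        if isOver:
            # record the finished episode, like AtariPlayer.action() above
            self.stats['score'].append(self.current_accum_score)
            self.current_accum_score = 0
        return (reward, isOver)

    def get_stat(self):
        # summarize finished episodes; empty dict when nothing was recorded yet
        if not self.stats['score']:
            return {}
        return {'avg_score': float(np.mean(self.stats['score'])),
                'max_score': float(np.max(self.stats['score']))}

    def reset_stat(self):
        self.stats = defaultdict(list)

At epoch end, a consumer such as the `ExpReplay` trigger shown above would then read `player.get_stat()`, write each float value as a scalar summary, and call `player.reset_stat()` to start a fresh window of episodes.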