Commit dee9d398 authored by Yuxin Wu's avatar Yuxin Wu

restart_episode in RLENV

parent 4e472eb5
......@@ -155,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
if verbose:
print(act)
return act
return player.play_one_episode(f)
return np.mean(player.play_one_episode(f))
def play_model(model_path):
player = get_player(0.013)
......
......@@ -34,7 +34,7 @@ class AtariPlayer(RLEnvironment):
live_lost_as_eoe=True):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param frame_skip: skip every k frames and repeat the action
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable
......@@ -57,10 +57,8 @@ class AtariPlayer(RLEnvironment):
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
......@@ -77,9 +75,9 @@ class AtariPlayer(RLEnvironment):
self.nullop_start = nullop_start
self.height_range = height_range
self.image_shape = image_shape
self.current_episode_score = StatCounter()
self._reset()
self.current_episode_score = StatCounter()
self.restart_episode()
def _grab_raw_image(self):
"""
......@@ -112,7 +110,9 @@ class AtariPlayer(RLEnvironment):
"""
return len(self.actions)
def _reset(self):
def restart_episode(self):
if self.current_episode_score.count > 0:
self.stats['score'].append(self.current_episode_score.sum)
self.current_episode_score.reset()
self.ale.reset_game()
......@@ -143,8 +143,7 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
if isOver:
self.stats['score'].append(self.current_episode_score.sum)
self._reset()
self.restart_episode()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
return (r, isOver)
......
......@@ -39,6 +39,11 @@ class HistoryFramePlayer(ProxyPlayer):
self.history.append(s)
return (r, isOver)
def restart_episode(self):
super(HistoryFramePlayer, self).restart_episode()
self.history.clear()
self.history.append(self.player.current_state())
class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
......@@ -63,6 +68,10 @@ class PreventStuckPlayer(ProxyPlayer):
self.act_que.clear()
return (r, isOver)
def restart_episode(self):
super(PreventStuckPlayer, self).restart_episode()
self.act_que.clear()
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode"""
def __init__(self, player, limit):
......@@ -73,8 +82,14 @@ class LimitLengthPlayer(ProxyPlayer):
def action(self, act):
r, isOver = self.player.action(act)
self.cnt += 1
if self.cnt == self.limit:
if self.cnt >= self.limit:
isOver = True
self.player.restart_episode()
if isOver:
self.cnt == 0
print self.cnt
self.cnt = 0
return (r, isOver)
def restart_episode(self):
super(LimitLengthPlayer, self).restart_episode()
self.cnt = 0
......@@ -24,16 +24,20 @@ class RLEnvironment(object):
@abstractmethod
def action(self, act):
"""
Perform an action
Perform an action. Will automatically start a new episode if isOver==True
:params act: the action
:returns: (reward, isOver)
"""
@abstractmethod
def restart_episode(self):
""" Start a new episode, even if the current hasn't ended """
def get_stat(self):
"""
return a dict of statistics (e.g., score) after running for a while
return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
"""
return {}
def reset_stat(self):
""" reset the statistics counter"""
......@@ -63,6 +67,8 @@ class NaiveRLEnvironment(RLEnvironment):
def action(self, act):
self.k = act
return (self.k, self.k > 10)
def restart_episode(self):
pass
class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """
......@@ -85,6 +91,5 @@ class ProxyPlayer(RLEnvironment):
def stats(self):
return self.player.stats
def play_one_episode(self, func, stat='score'):
return self.player.play_one_episode(self, func, stat)
def restart_episode(self):
self.player.restart_episode()
......@@ -74,7 +74,7 @@ def add_param_summary(summary_lists):
name = p.name
for rgx, actions in summary_lists:
if not rgx.endswith('$'):
rgx = rgx + '$'
rgx = rgx + '(:0)?$'
if re.match(rgx, name):
for act in actions:
perform(p, act)
......
......@@ -21,14 +21,17 @@ class StatCounter(object):
@property
def average(self):
assert len(self.values)
return np.mean(self.values)
@property
def sum(self):
assert len(self.values)
return np.sum(self.values)
@property
def max(self):
assert len(self.values)
return max(self.values)
class Accuracy(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment