Commit dee9d398 authored by Yuxin Wu's avatar Yuxin Wu

restart_episode in RLENV

parent 4e472eb5
...@@ -155,7 +155,7 @@ def play_one_episode(player, func, verbose=False): ...@@ -155,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
if verbose: if verbose:
print(act) print(act)
return act return act
return player.play_one_episode(f) return np.mean(player.play_one_episode(f))
def play_model(model_path): def play_model(model_path):
player = get_player(0.013) player = get_player(0.013)
......
...@@ -34,7 +34,7 @@ class AtariPlayer(RLEnvironment): ...@@ -34,7 +34,7 @@ class AtariPlayer(RLEnvironment):
live_lost_as_eoe=True): live_lost_as_eoe=True):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames and repeat the action
:param image_shape: (w, h) :param image_shape: (w, h)
:param height_range: (h1, h2) to cut :param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable :param viz: the delay. visualize the game while running. 0 to disable
...@@ -57,10 +57,8 @@ class AtariPlayer(RLEnvironment): ...@@ -57,10 +57,8 @@ class AtariPlayer(RLEnvironment):
self.ale.setBool('color_averaging', False) self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check # manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0) self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
...@@ -77,9 +75,9 @@ class AtariPlayer(RLEnvironment): ...@@ -77,9 +75,9 @@ class AtariPlayer(RLEnvironment):
self.nullop_start = nullop_start self.nullop_start = nullop_start
self.height_range = height_range self.height_range = height_range
self.image_shape = image_shape self.image_shape = image_shape
self.current_episode_score = StatCounter()
self._reset() self.current_episode_score = StatCounter()
self.restart_episode()
def _grab_raw_image(self): def _grab_raw_image(self):
""" """
...@@ -112,7 +110,9 @@ class AtariPlayer(RLEnvironment): ...@@ -112,7 +110,9 @@ class AtariPlayer(RLEnvironment):
""" """
return len(self.actions) return len(self.actions)
def _reset(self): def restart_episode(self):
if self.current_episode_score.count > 0:
self.stats['score'].append(self.current_episode_score.sum)
self.current_episode_score.reset() self.current_episode_score.reset()
self.ale.reset_game() self.ale.reset_game()
...@@ -143,8 +143,7 @@ class AtariPlayer(RLEnvironment): ...@@ -143,8 +143,7 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r) self.current_episode_score.feed(r)
isOver = self.ale.game_over() isOver = self.ale.game_over()
if isOver: if isOver:
self.stats['score'].append(self.current_episode_score.sum) self.restart_episode()
self._reset()
if self.live_lost_as_eoe: if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives isOver = isOver or newlives < oldlives
return (r, isOver) return (r, isOver)
......
...@@ -39,6 +39,11 @@ class HistoryFramePlayer(ProxyPlayer): ...@@ -39,6 +39,11 @@ class HistoryFramePlayer(ProxyPlayer):
self.history.append(s) self.history.append(s)
return (r, isOver) return (r, isOver)
def restart_episode(self):
super(HistoryFramePlayer, self).restart_episode()
self.history.clear()
self.history.append(self.player.current_state())
class PreventStuckPlayer(ProxyPlayer): class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op) """ Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout by inserting a different action. Useful in games such as Atari Breakout
...@@ -63,6 +68,10 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -63,6 +68,10 @@ class PreventStuckPlayer(ProxyPlayer):
self.act_que.clear() self.act_que.clear()
return (r, isOver) return (r, isOver)
def restart_episode(self):
super(PreventStuckPlayer, self).restart_episode()
self.act_que.clear()
class LimitLengthPlayer(ProxyPlayer): class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode""" """ Limit the total number of actions in an episode"""
def __init__(self, player, limit): def __init__(self, player, limit):
...@@ -73,8 +82,14 @@ class LimitLengthPlayer(ProxyPlayer): ...@@ -73,8 +82,14 @@ class LimitLengthPlayer(ProxyPlayer):
def action(self, act): def action(self, act):
r, isOver = self.player.action(act) r, isOver = self.player.action(act)
self.cnt += 1 self.cnt += 1
if self.cnt == self.limit: if self.cnt >= self.limit:
isOver = True isOver = True
self.player.restart_episode()
if isOver: if isOver:
self.cnt == 0 print self.cnt
self.cnt = 0
return (r, isOver) return (r, isOver)
def restart_episode(self):
super(LimitLengthPlayer, self).restart_episode()
self.cnt = 0
...@@ -24,16 +24,20 @@ class RLEnvironment(object): ...@@ -24,16 +24,20 @@ class RLEnvironment(object):
@abstractmethod @abstractmethod
def action(self, act): def action(self, act):
""" """
Perform an action Perform an action. Will automatically start a new episode if isOver==True
:params act: the action :params act: the action
:returns: (reward, isOver) :returns: (reward, isOver)
""" """
@abstractmethod @abstractmethod
def restart_episode(self):
""" Start a new episode, even if the current hasn't ended """
def get_stat(self): def get_stat(self):
""" """
return a dict of statistics (e.g., score) after running for a while return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
""" """
return {}
def reset_stat(self): def reset_stat(self):
""" reset the statistics counter""" """ reset the statistics counter"""
...@@ -63,6 +67,8 @@ class NaiveRLEnvironment(RLEnvironment): ...@@ -63,6 +67,8 @@ class NaiveRLEnvironment(RLEnvironment):
def action(self, act): def action(self, act):
self.k = act self.k = act
return (self.k, self.k > 10) return (self.k, self.k > 10)
def restart_episode(self):
pass
class ProxyPlayer(RLEnvironment): class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """ """ Serve as a proxy another player """
...@@ -85,6 +91,5 @@ class ProxyPlayer(RLEnvironment): ...@@ -85,6 +91,5 @@ class ProxyPlayer(RLEnvironment):
def stats(self): def stats(self):
return self.player.stats return self.player.stats
def play_one_episode(self, func, stat='score'): def restart_episode(self):
return self.player.play_one_episode(self, func, stat) self.player.restart_episode()
...@@ -74,7 +74,7 @@ def add_param_summary(summary_lists): ...@@ -74,7 +74,7 @@ def add_param_summary(summary_lists):
name = p.name name = p.name
for rgx, actions in summary_lists: for rgx, actions in summary_lists:
if not rgx.endswith('$'): if not rgx.endswith('$'):
rgx = rgx + '$' rgx = rgx + '(:0)?$'
if re.match(rgx, name): if re.match(rgx, name):
for act in actions: for act in actions:
perform(p, act) perform(p, act)
......
...@@ -21,14 +21,17 @@ class StatCounter(object): ...@@ -21,14 +21,17 @@ class StatCounter(object):
@property @property
def average(self): def average(self):
assert len(self.values)
return np.mean(self.values) return np.mean(self.values)
@property @property
def sum(self): def sum(self):
assert len(self.values)
return np.sum(self.values) return np.sum(self.values)
@property @property
def max(self): def max(self):
assert len(self.values)
return max(self.values) return max(self.values)
class Accuracy(object): class Accuracy(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment