Commit b61d0722 authored by Yuxin Wu

finish_episode in RLenv

parent 300b2c3a
...@@ -32,7 +32,8 @@ _ALE_LOCK = threading.Lock() ...@@ -32,7 +32,8 @@ _ALE_LOCK = threading.Lock()
class AtariPlayer(RLEnvironment): class AtariPlayer(RLEnvironment):
""" """
A wrapper for atari emulator. A wrapper for atari emulator.
NOTE: will automatically restart when a real episode ends Will automatically restart when a real episode ends (isOver might be just
loss of lives but not game over).
""" """
def __init__(self, rom_file, viz=0, height_range=(None,None), def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84), nullop_start=30, frame_skip=4, image_shape=(84, 84), nullop_start=30,
...@@ -129,9 +130,10 @@ class AtariPlayer(RLEnvironment): ...@@ -129,9 +130,10 @@ class AtariPlayer(RLEnvironment):
def get_action_space(self): def get_action_space(self):
return DiscreteActionSpace(len(self.actions)) return DiscreteActionSpace(len(self.actions))
def finish_episode(self):
self.stats['score'].append(self.current_episode_score.sum)
def restart_episode(self): def restart_episode(self):
if self.current_episode_score.count > 0:
self.stats['score'].append(self.current_episode_score.sum)
self.current_episode_score.reset() self.current_episode_score.reset()
self.ale.reset_game() self.ale.reset_game()
...@@ -162,6 +164,7 @@ class AtariPlayer(RLEnvironment): ...@@ -162,6 +164,7 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r) self.current_episode_score.feed(r)
isOver = self.ale.game_over() isOver = self.ale.game_over()
if isOver: if isOver:
self.finish_episode()
self.restart_episode() self.restart_episode()
if self.live_lost_as_eoe: if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives isOver = isOver or newlives < oldlives
......
...@@ -40,7 +40,9 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -40,7 +40,9 @@ class PreventStuckPlayer(ProxyPlayer):
self.act_que.clear() self.act_que.clear()
class LimitLengthPlayer(ProxyPlayer): class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode""" """ Limit the total number of actions in an episode.
Does not auto restart.
"""
def __init__(self, player, limit): def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player) super(LimitLengthPlayer, self).__init__(player)
self.limit = limit self.limit = limit
...@@ -51,10 +53,6 @@ class LimitLengthPlayer(ProxyPlayer): ...@@ -51,10 +53,6 @@ class LimitLengthPlayer(ProxyPlayer):
self.cnt += 1 self.cnt += 1
if self.cnt >= self.limit: if self.cnt >= self.limit:
isOver = True isOver = True
self.player.restart_episode()
if isOver:
#print self.cnt, self.player.stats # to see what limit is appropriate
self.cnt = 0
return (r, isOver) return (r, isOver)
def restart_episode(self): def restart_episode(self):
...@@ -67,6 +65,7 @@ class AutoRestartPlayer(ProxyPlayer): ...@@ -67,6 +65,7 @@ class AutoRestartPlayer(ProxyPlayer):
def action(self, act): def action(self, act):
r, isOver = self.player.action(act) r, isOver = self.player.action(act)
if isOver: if isOver:
self.player.finish_episode()
self.player.restart_episode() self.player.restart_episode()
return r, isOver return r, isOver
......
...@@ -36,6 +36,10 @@ class RLEnvironment(object): ...@@ -36,6 +36,10 @@ class RLEnvironment(object):
""" Start a new episode, even if the current hasn't ended """ """ Start a new episode, even if the current hasn't ended """
raise NotImplementedError() raise NotImplementedError()
def finish_episode(self):
""" Gets called when an episode has finished. """
pass
def get_action_space(self): def get_action_space(self):
""" return an `ActionSpace` instance""" """ return an `ActionSpace` instance"""
raise NotImplementedError() raise NotImplementedError()
...@@ -112,5 +116,8 @@ class ProxyPlayer(RLEnvironment): ...@@ -112,5 +116,8 @@ class ProxyPlayer(RLEnvironment):
def restart_episode(self): def restart_episode(self):
self.player.restart_episode() self.player.restart_episode()
def finish_episode(self):
self.player.finish_episode()
def get_action_space(self): def get_action_space(self):
return self.player.get_action_space() return self.player.get_action_space()
...@@ -49,14 +49,3 @@ class HistoryFramePlayer(ProxyPlayer): ...@@ -49,14 +49,3 @@ class HistoryFramePlayer(ProxyPlayer):
self.history.clear() self.history.clear()
self.history.append(self.player.current_state()) self.history.append(self.player.current_state())
class TimePointHistoryFramePlayer(HistoryFramePlayer):
""" Include history from a list of time points in the past"""
def __init__(self, player, hists):
""" hists: a list of positive integers. 1 means the last frame"""
queue_size = max(hists) + 1
super(TimePointHistoryFramePlayer, self).__init__(player, queue_size)
self.hists = hists
def current_state(self):
# TODO
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment