Commit b61d0722 authored by Yuxin Wu

finish_episode in RLenv

parent 300b2c3a
......@@ -32,7 +32,8 @@ _ALE_LOCK = threading.Lock()
class AtariPlayer(RLEnvironment):
"""
A wrapper for atari emulator.
NOTE: will automatically restart when a real episode ends
Will automatically restart when a real episode ends (isOver might only mean
that a life was lost, not that the game is over).
"""
def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84), nullop_start=30,
......@@ -129,9 +130,10 @@ class AtariPlayer(RLEnvironment):
def get_action_space(self):
    """ Return a `DiscreteActionSpace` with one entry per element of `self.actions`."""
    num_actions = len(self.actions)
    return DiscreteActionSpace(num_actions)
def restart_episode(self):
if self.current_episode_score.count > 0:
def finish_episode(self):
    """ Record the score of the episode that just ended.

    Appends the accumulated reward (`self.current_episode_score.sum`) to
    `self.stats['score']`. Nothing is recorded when no reward has been fed
    since the last reset, so a second call made by an outer wrapper after
    the episode was already finished and restarted does not add a bogus
    entry.
    """
    episode_score = self.current_episode_score
    # Guard: an empty counter means this episode was already recorded and
    # reset (or never started), so there is nothing meaningful to append.
    if episode_score.count > 0:
        self.stats['score'].append(episode_score.sum)
def restart_episode(self):
self.current_episode_score.reset()
self.ale.reset_game()
......@@ -162,6 +164,7 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
if isOver:
self.finish_episode()
self.restart_episode()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
......
......@@ -40,7 +40,9 @@ class PreventStuckPlayer(ProxyPlayer):
self.act_que.clear()
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode"""
""" Limit the total number of actions in an episode.
Does not auto restart.
"""
def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player)
self.limit = limit
......@@ -51,10 +53,6 @@ class LimitLengthPlayer(ProxyPlayer):
self.cnt += 1
if self.cnt >= self.limit:
isOver = True
self.player.restart_episode()
if isOver:
#print self.cnt, self.player.stats # to see what limit is appropriate
self.cnt = 0
return (r, isOver)
def restart_episode(self):
......@@ -67,6 +65,7 @@ class AutoRestartPlayer(ProxyPlayer):
def action(self, act):
    """ Forward `act` to the wrapped player and restart it when the episode ends.

    Returns the `(reward, isOver)` pair produced by the wrapped player.
    """
    reward, episode_over = self.player.action(act)
    if not episode_over:
        return reward, episode_over
    # Finish first, then restart the wrapped player.
    self.player.finish_episode()
    self.player.restart_episode()
    return reward, episode_over
......
......@@ -36,6 +36,10 @@ class RLEnvironment(object):
""" Start a new episode, even if the current hasn't ended """
raise NotImplementedError()
def finish_episode(self):
    """ Hook called when an episode has finished.

    The base implementation does nothing; subclasses may override it.
    """
    return None
def get_action_space(self):
    """ Return an `ActionSpace` instance.

    Not implemented in the base class: always raises NotImplementedError.
    """
    raise NotImplementedError
......@@ -112,5 +116,8 @@ class ProxyPlayer(RLEnvironment):
def restart_episode(self):
    """ Delegate the restart to the wrapped player."""
    wrapped = self.player
    wrapped.restart_episode()
def finish_episode(self):
    """ Forward the end-of-episode notification to the wrapped player."""
    wrapped = self.player
    wrapped.finish_episode()
def get_action_space(self):
    """ Return whatever action space the wrapped player reports."""
    space = self.player.get_action_space()
    return space
......@@ -49,14 +49,3 @@ class HistoryFramePlayer(ProxyPlayer):
self.history.clear()
self.history.append(self.player.current_state())
class TimePointHistoryFramePlayer(HistoryFramePlayer):
    """ Include history from a list of time points in the past"""

    def __init__(self, player, hists):
        """ hists: a list of positive integers. 1 means the last frame"""
        # Second argument handed to HistoryFramePlayer.__init__: one more
        # than the largest requested time point.
        frames_needed = 1 + max(hists)
        super(TimePointHistoryFramePlayer, self).__init__(player, frames_needed)
        self.hists = hists

    def current_state(self):
        """ Not implemented yet; currently returns None."""
        # TODO
        return None
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment