Commit 7da1d899 authored by Yuxin Wu

auto restart player

parent 92073acc
......@@ -28,6 +28,7 @@ def log_once():
class AtariPlayer(RLEnvironment):
"""
A wrapper for atari emulator.
NOTE: will automatically restart when a real episode ends
"""
def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84), nullop_start=30,
......@@ -58,7 +59,7 @@ class AtariPlayer(RLEnvironment):
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
# manual.pdf suggests otherwise.
self.ale.setFloat('repeat_action_probability', 0.0)
# viz setup
......
......@@ -8,41 +8,7 @@ import numpy as np
from collections import deque
from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer', 'LimitLengthPlayer']
class HistoryFramePlayer(ProxyPlayer):
    """ Present the most recent frames of the wrapped player as one state.

    Up to ``hist_len`` frames are kept and joined along axis 2. Until that
    many frames have been recorded, the missing leading slots are filled
    with all-zero (black) images.
    """

    def __init__(self, player, hist_len):
        """ Wrap ``player`` and seed the history with its current state.

        :param player: the player being wrapped.
        :param hist_len: maximum number of frames kept in the history.
        """
        super(HistoryFramePlayer, self).__init__(player)
        self.history = deque(maxlen=hist_len)
        self.history.append(self.player.current_state())

    def current_state(self):
        """ Return the stored frames concatenated along axis 2.

        When fewer frames than the history capacity are stored, zero images
        shaped like the oldest stored frame are placed in front of them.
        """
        assert len(self.history) != 0
        missing = self.history.maxlen - len(self.history)
        if missing == 0:
            return np.concatenate(self.history, axis=2)
        # Pad at the front so the real frames keep their chronological order.
        blank = np.zeros_like(self.history[0])
        frames = [blank] * missing
        frames.extend(self.history)
        return np.concatenate(frames, axis=2)

    def action(self, act):
        """ Forward ``act`` to the wrapped player and record the new frame.

        :returns: the ``(reward, isOver)`` pair reported by the wrapped player.
        """
        reward, episode_over = self.player.action(act)
        latest = self.player.current_state()
        if episode_over:
            # The state read after the episode ended already belongs to a
            # new episode, so it must be the only frame left in the history.
            self.history.clear()
        self.history.append(latest)
        return (reward, episode_over)

    def restart_episode(self):
        """ Restart the wrapped player and reset the history to its state. """
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.append(self.player.current_state())
__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer']
class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
......@@ -93,3 +59,12 @@ class LimitLengthPlayer(ProxyPlayer):
def restart_episode(self):
    """ Restart the underlying episode, then reset this wrapper's counter. """
    super(LimitLengthPlayer, self).restart_episode()
    # A fresh episode starts counting from zero again.
    self.cnt = 0
class AutoRestartPlayer(ProxyPlayer):
    """ Restart the wrapped player whenever it reports an episode end.

    Intended for wrapped players that do not restart by themselves.
    """

    def action(self, act):
        """ Forward ``act``; restart the wrapped player if the episode ended.

        :returns: the ``(reward, isOver)`` pair reported by the wrapped player.
        """
        outcome = self.player.action(act)
        reward, episode_over = outcome
        if episode_over:
            self.player.restart_episode()
        return reward, episode_over
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment