Commit 64707bfa authored by Yuxin Wu

better atari env

parent 80722088
...@@ -8,6 +8,7 @@ import time ...@@ -8,6 +8,7 @@ import time
import os import os
import cv2 import cv2
from collections import deque from collections import deque
from six.moves import range
from ..utils import get_rng, logger from ..utils import get_rng, logger
from ..utils.stat import StatCounter from ..utils.stat import StatCounter
...@@ -25,7 +26,7 @@ class AtariPlayer(RLEnvironment): ...@@ -25,7 +26,7 @@ class AtariPlayer(RLEnvironment):
A wrapper for atari emulator. A wrapper for atari emulator.
""" """
def __init__(self, rom_file, viz=0, height_range=(None,None), def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84)): frame_skip=4, image_shape=(84, 84), nullop_start=30):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames
...@@ -37,9 +38,12 @@ class AtariPlayer(RLEnvironment): ...@@ -37,9 +38,12 @@ class AtariPlayer(RLEnvironment):
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 1000)) self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setInt("frame_skip", frame_skip) self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', True) self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
...@@ -51,8 +55,9 @@ class AtariPlayer(RLEnvironment): ...@@ -51,8 +55,9 @@ class AtariPlayer(RLEnvironment):
if self.viz and isinstance(self.viz, float): if self.viz and isinstance(self.viz, float):
cv2.startWindowThread() cv2.startWindowThread()
cv2.namedWindow(self.romname) cv2.namedWindow(self.romname)
self.framenum = 0
self.frame_skip = frame_skip
self.nullop_start = nullop_start
self.height_range = height_range self.height_range = height_range
self.image_shape = image_shape self.image_shape = image_shape
self.current_episode_score = StatCounter() self.current_episode_score = StatCounter()
...@@ -63,8 +68,7 @@ class AtariPlayer(RLEnvironment): ...@@ -63,8 +68,7 @@ class AtariPlayer(RLEnvironment):
""" """
:returns: the current 3-channel image :returns: the current 3-channel image
""" """
m = np.zeros(self.height * self.width * 3, dtype=np.uint8) m = self.ale.getScreenRGB()
self.ale.getScreenRGB(m)
return m.reshape((self.height, self.width, 3)) return m.reshape((self.height, self.width, 3))
def current_state(self): def current_state(self):
...@@ -72,15 +76,15 @@ class AtariPlayer(RLEnvironment): ...@@ -72,15 +76,15 @@ class AtariPlayer(RLEnvironment):
:returns: a gray-scale (h, w, 1) image :returns: a gray-scale (h, w, 1) image
""" """
ret = self._grab_raw_image() ret = self._grab_raw_image()
# max-pooled over the last screen
ret = np.maximum(ret, self.last_raw_screen)
if self.viz: if self.viz:
if isinstance(self.viz, float): if isinstance(self.viz, float):
cv2.imshow(self.romname, ret) cv2.imshow(self.romname, ret)
time.sleep(self.viz) time.sleep(self.viz)
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = ret[self.height_range[0]:self.height_range[1],:] ret = ret[self.height_range[0]:self.height_range[1],:]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] # luma weights 0.299, 0.587, 0.114; same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.resize(ret, self.image_shape) ret = cv2.resize(ret, self.image_shape)
ret = np.expand_dims(ret, axis=2) ret = np.expand_dims(ret, axis=2)
return ret return ret
...@@ -95,17 +99,34 @@ class AtariPlayer(RLEnvironment): ...@@ -95,17 +99,34 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.reset() self.current_episode_score.reset()
self.ale.reset_game() self.ale.reset_game()
# random null-ops start
n = self.rng.randint(self.nullop_start)
for k in range(n):
if k == n - 1:
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def action(self, act): def action(self, act):
""" """
:param act: an index of the action :param act: an index of the action
:returns: (reward, isOver) :returns: (reward, isOver)
""" """
r = self.ale.act(self.actions[act]) oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
if k == self.frame_skip - 1:
self.last_raw_screen = self._grab_raw_image()
r += self.ale.act(self.actions[act])
newlives = self.ale.lives()
if self.ale.game_over() or newlives < oldlives:
break
self.current_episode_score.feed(r) self.current_episode_score.feed(r)
isOver = self.ale.game_over() isOver = self.ale.game_over()
if isOver: if isOver:
self.stats['score'].append(self.current_episode_score.sum) self.stats['score'].append(self.current_episode_score.sum)
self._reset() self._reset()
isOver = isOver or newlives < oldlives
return (r, isOver) return (r, isOver)
def get_stat(self): def get_stat(self):
...@@ -118,7 +139,7 @@ class AtariPlayer(RLEnvironment): ...@@ -118,7 +139,7 @@ class AtariPlayer(RLEnvironment):
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
a = AtariPlayer(sys.argv[1], a = AtariPlayer(sys.argv[1],
viz=0.01, height_range=(28,-8)) viz=0.03, height_range=(28,-8))
num = a.get_num_actions() num = a.get_num_actions()
rng = get_rng(num) rng = get_rng(num)
import time import time
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment