Commit 64707bfa authored by Yuxin Wu's avatar Yuxin Wu

better atari env

parent 80722088
......@@ -8,6 +8,7 @@ import time
import os
import cv2
from collections import deque
from six.moves import range
from ..utils import get_rng, logger
from ..utils.stat import StatCounter
......@@ -25,7 +26,7 @@ class AtariPlayer(RLEnvironment):
A wrapper for atari emulator.
"""
def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84)):
frame_skip=4, image_shape=(84, 84), nullop_start=30):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
......@@ -37,9 +38,12 @@ class AtariPlayer(RLEnvironment):
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 1000))
self.ale.setInt("frame_skip", frame_skip)
self.ale.setBool('color_averaging', True)
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
......@@ -51,8 +55,9 @@ class AtariPlayer(RLEnvironment):
if self.viz and isinstance(self.viz, float):
cv2.startWindowThread()
cv2.namedWindow(self.romname)
self.framenum = 0
self.frame_skip = frame_skip
self.nullop_start = nullop_start
self.height_range = height_range
self.image_shape = image_shape
self.current_episode_score = StatCounter()
......@@ -63,8 +68,7 @@ class AtariPlayer(RLEnvironment):
"""
:returns: the current 3-channel image
"""
m = np.zeros(self.height * self.width * 3, dtype=np.uint8)
self.ale.getScreenRGB(m)
m = self.ale.getScreenRGB()
return m.reshape((self.height, self.width, 3))
def current_state(self):
......@@ -72,15 +76,15 @@ class AtariPlayer(RLEnvironment):
:returns: a gray-scale (h, w, 1) image
"""
ret = self._grab_raw_image()
# max-pooled over the last screen
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
time.sleep(self.viz)
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = ret[self.height_range[0]:self.height_range[1],:]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
# 0.299,0.587.0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.resize(ret, self.image_shape)
ret = np.expand_dims(ret, axis=2)
return ret
......@@ -95,17 +99,34 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.reset()
self.ale.reset_game()
# random null-ops start
n = self.rng.randint(self.nullop_start)
for k in range(n):
if k == n - 1:
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def action(self, act):
"""
:param act: an index of the action
:returns: (reward, isOver)
"""
r = self.ale.act(self.actions[act])
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
if k == self.frame_skip - 1:
self.last_raw_screen = self._grab_raw_image()
r += self.ale.act(self.actions[act])
newlives = self.ale.lives()
if self.ale.game_over() or newlives < oldlives:
break
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
if isOver:
self.stats['score'].append(self.current_episode_score.sum)
self._reset()
isOver = isOver or newlives < oldlives
return (r, isOver)
def get_stat(self):
......@@ -118,7 +139,7 @@ class AtariPlayer(RLEnvironment):
if __name__ == '__main__':
import sys
a = AtariPlayer(sys.argv[1],
viz=0.01, height_range=(28,-8))
viz=0.03, height_range=(28,-8))
num = a.get_num_actions()
rng = get_rng(num)
import time
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment