Commit 5c0d2c9b authored by Yuxin Wu's avatar Yuxin Wu

atari viz

parent 7346f13b
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from ale_python_interface import ALEInterface from ale_python_interface import ALEInterface
import numpy as np import numpy as np
import time
import os import os
import cv2 import cv2
from .utils import get_rng from .utils import get_rng
...@@ -15,16 +16,16 @@ class AtariDriver(object): ...@@ -15,16 +16,16 @@ class AtariDriver(object):
""" """
A driver for atari games. A driver for atari games.
""" """
def __init__(self, rom_file, frame_skip=1, viz=False): def __init__(self, rom_file, frame_skip=1, viz=0):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames
:param viz: visualize the game while running :param viz: the delay. visualize the game while running. 0 to disable
""" """
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(99999)) self.ale.setInt("random_seed", self.rng.randint(214))
self.ale.setInt("frame_skip", frame_skip) self.ale.setInt("frame_skip", frame_skip)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
...@@ -56,6 +57,7 @@ class AtariDriver(object): ...@@ -56,6 +57,7 @@ class AtariDriver(object):
self.last_image = now self.last_image = now
if self.viz: if self.viz:
cv2.imshow(self.romname, ret) cv2.imshow(self.romname, ret)
time.sleep(self.viz)
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment