Commit 5c0d2c9b authored by Yuxin Wu's avatar Yuxin Wu

atari viz

parent 7346f13b
......@@ -5,6 +5,7 @@
from ale_python_interface import ALEInterface
import numpy as np
import time
import os
import cv2
from .utils import get_rng
......@@ -15,16 +16,16 @@ class AtariDriver(object):
"""
A driver for atari games.
"""
def __init__(self, rom_file, frame_skip=1, viz=False):
def __init__(self, rom_file, frame_skip=1, viz=0):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param viz: visualize the game while running
:param viz: the delay. visualize the game while running. 0 to disable
"""
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(99999))
self.ale.setInt("random_seed", self.rng.randint(214))
self.ale.setInt("frame_skip", frame_skip)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
......@@ -56,6 +57,7 @@ class AtariDriver(object):
self.last_image = now
if self.viz:
cv2.imshow(self.romname, ret)
time.sleep(self.viz)
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment