Commit 364fe347 authored by Yuxin Wu's avatar Yuxin Wu

lock for atari simulator init

parent f461ed2e
...@@ -52,7 +52,6 @@ def eval_with_funcs(predict_funcs, nr_eval): ...@@ -52,7 +52,6 @@ def eval_with_funcs(predict_funcs, nr_eval):
for k in threads: for k in threads:
k.start() k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter() stat = StatCounter()
try: try:
for _ in tqdm(range(nr_eval)): for _ in tqdm(range(nr_eval)):
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import time, os import time, os
import cv2 import cv2
from collections import deque from collections import deque
import threading
import six import six
from six.moves import range from six.moves import range
from ..utils import get_rng, logger, memoized, get_dataset_dir from ..utils import get_rng, logger, memoized, get_dataset_dir
...@@ -25,6 +26,8 @@ __all__ = ['AtariPlayer'] ...@@ -25,6 +26,8 @@ __all__ = ['AtariPlayer']
def log_once(): def log_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!") logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
_ALE_LOCK = threading.Lock()
class AtariPlayer(RLEnvironment): class AtariPlayer(RLEnvironment):
""" """
A wrapper for atari emulator. A wrapper for atari emulator.
...@@ -50,36 +53,38 @@ class AtariPlayer(RLEnvironment): ...@@ -50,36 +53,38 @@ class AtariPlayer(RLEnvironment):
rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file) rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file)
assert os.path.isfile(rom_file), "rom {} not found".format(rom_file) assert os.path.isfile(rom_file), "rom {} not found".format(rom_file)
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
try: try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning) ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError: except AttributeError:
log_once() log_once()
self.ale.setInt("frame_skip", 1) # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
self.ale.setBool('color_averaging', False) with _ALE_LOCK:
# manual.pdf suggests otherwise. self.ale = ALEInterface()
self.ale.setFloat('repeat_action_probability', 0.0) self.rng = get_rng(self)
# viz setup self.ale.setInt("random_seed", self.rng.randint(0, 10000))
if isinstance(viz, six.string_types): self.ale.setBool("showinfo", False)
assert os.path.isdir(viz), viz
self.ale.setString('record_screen_dir', viz) self.ale.setInt("frame_skip", 1)
viz = 0 self.ale.setBool('color_averaging', False)
if isinstance(viz, int): # manual.pdf suggests otherwise.
viz = float(viz) self.ale.setFloat('repeat_action_probability', 0.0)
self.viz = viz
if self.viz and isinstance(self.viz, float): # viz setup
self.windowname = os.path.basename(rom_file) if isinstance(viz, six.string_types):
cv2.startWindowThread() assert os.path.isdir(viz), viz
cv2.namedWindow(self.windowname) self.ale.setString('record_screen_dir', viz)
viz = 0
self.ale.loadROM(rom_file) if isinstance(viz, int):
viz = float(viz)
self.viz = viz
if self.viz and isinstance(self.viz, float):
self.windowname = os.path.basename(rom_file)
cv2.startWindowThread()
cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment