Commit 9319b978 authored by Yuxin Wu's avatar Yuxin Wu

height range for atari

parent 6e1f395d
......@@ -36,6 +36,7 @@ IMAGE_SIZE = 84
NUM_ACTIONS = None
FRAME_HISTORY = 4
ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout
GAMMA = 0.99
BATCH_SIZE = 32
......@@ -154,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
return tot_reward
def play_model(model_path, romfile):
player = AtariPlayer(AtariDriver(romfile, viz=0.01),
player = AtariPlayer(AtariDriver(romfile, viz=0.01, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
......@@ -184,7 +185,7 @@ def eval_model_multiprocess(model_path, romfile):
self.outq = outqueue
def run(self):
player = AtariPlayer(AtariDriver(romfile, viz=0),
player = AtariPlayer(AtariDriver(romfile, viz=0, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
......@@ -224,7 +225,7 @@ def get_config(romfile):
os.path.join('train_log', basename[:basename.rfind('.')]))
M = Model()
driver = AtariDriver(romfile)
driver = AtariDriver(romfile, height_range=HEIGHT_RANGE)
global NUM_ACTIONS
NUM_ACTIONS = driver.get_num_actions()
......
......@@ -22,7 +22,8 @@ class AtariDriver(object):
"""
A wrapper for atari emulator.
"""
def __init__(self, rom_file, frame_skip=1, viz=0):
def __init__(self, rom_file,
frame_skip=1, viz=0, height_range=(None,None)):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
......@@ -48,6 +49,7 @@ class AtariDriver(object):
self._reset()
self.last_image = self._grab_raw_image()
self.framenum = 0
self.height_range = height_range
def _grab_raw_image(self):
"""
......@@ -64,14 +66,15 @@ class AtariDriver(object):
now = self._grab_raw_image()
ret = np.maximum(now, self.last_image)
self.last_image = now
if self.viz and isinstance(self.viz, float):
ret = ret[self.height_range[0]:self.height_range[1],:] # several online repos all use this
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
time.sleep(self.viz)
elif self.viz:
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
ret = ret[36:204,:] # several online repos all use this
return ret
def get_num_actions(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment