Commit 9319b978 authored by Yuxin Wu's avatar Yuxin Wu

height range for atari

parent 6e1f395d
...@@ -36,6 +36,7 @@ IMAGE_SIZE = 84 ...@@ -36,6 +36,7 @@ IMAGE_SIZE = 84
NUM_ACTIONS = None NUM_ACTIONS = None
FRAME_HISTORY = 4 FRAME_HISTORY = 4
ACTION_REPEAT = 3 ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout
GAMMA = 0.99 GAMMA = 0.99
BATCH_SIZE = 32 BATCH_SIZE = 32
...@@ -154,7 +155,7 @@ def play_one_episode(player, func, verbose=False): ...@@ -154,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
return tot_reward return tot_reward
def play_model(model_path, romfile): def play_model(model_path, romfile):
player = AtariPlayer(AtariDriver(romfile, viz=0.01), player = AtariPlayer(AtariDriver(romfile, viz=0.01, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT) action_repeat=ACTION_REPEAT)
global NUM_ACTIONS global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions() NUM_ACTIONS = player.driver.get_num_actions()
...@@ -184,7 +185,7 @@ def eval_model_multiprocess(model_path, romfile): ...@@ -184,7 +185,7 @@ def eval_model_multiprocess(model_path, romfile):
self.outq = outqueue self.outq = outqueue
def run(self): def run(self):
player = AtariPlayer(AtariDriver(romfile, viz=0), player = AtariPlayer(AtariDriver(romfile, viz=0, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT) action_repeat=ACTION_REPEAT)
global NUM_ACTIONS global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions() NUM_ACTIONS = player.driver.get_num_actions()
...@@ -224,7 +225,7 @@ def get_config(romfile): ...@@ -224,7 +225,7 @@ def get_config(romfile):
os.path.join('train_log', basename[:basename.rfind('.')])) os.path.join('train_log', basename[:basename.rfind('.')]))
M = Model() M = Model()
driver = AtariDriver(romfile) driver = AtariDriver(romfile, height_range=HEIGHT_RANGE)
global NUM_ACTIONS global NUM_ACTIONS
NUM_ACTIONS = driver.get_num_actions() NUM_ACTIONS = driver.get_num_actions()
......
...@@ -22,7 +22,8 @@ class AtariDriver(object): ...@@ -22,7 +22,8 @@ class AtariDriver(object):
""" """
A wrapper for atari emulator. A wrapper for atari emulator.
""" """
def __init__(self, rom_file, frame_skip=1, viz=0): def __init__(self, rom_file,
frame_skip=1, viz=0, height_range=(None,None)):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames
...@@ -48,6 +49,7 @@ class AtariDriver(object): ...@@ -48,6 +49,7 @@ class AtariDriver(object):
self._reset() self._reset()
self.last_image = self._grab_raw_image() self.last_image = self._grab_raw_image()
self.framenum = 0 self.framenum = 0
self.height_range = height_range
def _grab_raw_image(self): def _grab_raw_image(self):
""" """
...@@ -64,14 +66,15 @@ class AtariDriver(object): ...@@ -64,14 +66,15 @@ class AtariDriver(object):
now = self._grab_raw_image() now = self._grab_raw_image()
ret = np.maximum(now, self.last_image) ret = np.maximum(now, self.last_image)
self.last_image = now self.last_image = now
if self.viz and isinstance(self.viz, float): ret = ret[self.height_range[0]:self.height_range[1],:] # several online repos all use this
cv2.imshow(self.romname, ret) if self.viz:
time.sleep(self.viz) if isinstance(self.viz, float):
elif self.viz: cv2.imshow(self.romname, ret)
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret) time.sleep(self.viz)
self.framenum += 1 else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
ret = ret[36:204,:] # several online repos all use this
return ret return ret
def get_num_actions(self): def get_num_actions(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment