Commit 5fd47e6d authored by Yuxin Wu's avatar Yuxin Wu

don't use eoe in eval

parent 64707bfa
...@@ -50,8 +50,10 @@ EVAL_EPISODE = 100 ...@@ -50,8 +50,10 @@ EVAL_EPISODE = 100
NUM_ACTIONS = None NUM_ACTIONS = None
ROM_FILE = None ROM_FILE = None
def get_player(viz=False): def get_player(viz=False, train=False):
pl = AtariPlayer(ROM_FILE, viz=viz, height_range=HEIGHT_RANGE, frame_skip=ACTION_REPEAT) player = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
live_lost_as_eoe=train)
global NUM_ACTIONS global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions() NUM_ACTIONS = pl.get_num_actions()
return pl return pl
...@@ -220,7 +222,7 @@ def get_config(): ...@@ -220,7 +222,7 @@ def get_config():
M = Model() M = Model()
dataset_train = ExpReplay( dataset_train = ExpReplay(
predictor=current_predictor, predictor=current_predictor,
player=get_player(), player=get_player(train=True),
num_actions=NUM_ACTIONS, num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE, memory_size=MEMORY_SIZE,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
......
...@@ -26,13 +26,16 @@ class AtariPlayer(RLEnvironment): ...@@ -26,13 +26,16 @@ class AtariPlayer(RLEnvironment):
A wrapper for atari emulator. A wrapper for atari emulator.
""" """
def __init__(self, rom_file, viz=0, height_range=(None,None), def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84), nullop_start=30): frame_skip=4, image_shape=(84, 84), nullop_start=30,
live_lost_as_eoe=True):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames
:param image_shape: (w, h) :param image_shape: (w, h)
:param height_range: (h1, h2) to cut :param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable :param viz: the delay. visualize the game while running. 0 to disable
:param nullop_start: start with random number of null ops
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
""" """
super(AtariPlayer, self).__init__() super(AtariPlayer, self).__init__()
self.ale = ALEInterface() self.ale = ALEInterface()
...@@ -45,6 +48,8 @@ class AtariPlayer(RLEnvironment): ...@@ -45,6 +48,8 @@ class AtariPlayer(RLEnvironment):
self.ale.setFloat('repeat_action_probability', 0.0) self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
...@@ -56,6 +61,7 @@ class AtariPlayer(RLEnvironment): ...@@ -56,6 +61,7 @@ class AtariPlayer(RLEnvironment):
cv2.startWindowThread() cv2.startWindowThread()
cv2.namedWindow(self.romname) cv2.namedWindow(self.romname)
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip self.frame_skip = frame_skip
self.nullop_start = nullop_start self.nullop_start = nullop_start
self.height_range = height_range self.height_range = height_range
...@@ -101,6 +107,7 @@ class AtariPlayer(RLEnvironment): ...@@ -101,6 +107,7 @@ class AtariPlayer(RLEnvironment):
# random null-ops start # random null-ops start
n = self.rng.randint(self.nullop_start) n = self.rng.randint(self.nullop_start)
self.last_raw_screen = self._grab_raw_image()
for k in range(n): for k in range(n):
if k == n - 1: if k == n - 1:
self.last_raw_screen = self._grab_raw_image() self.last_raw_screen = self._grab_raw_image()
...@@ -118,7 +125,8 @@ class AtariPlayer(RLEnvironment): ...@@ -118,7 +125,8 @@ class AtariPlayer(RLEnvironment):
self.last_raw_screen = self._grab_raw_image() self.last_raw_screen = self._grab_raw_image()
r += self.ale.act(self.actions[act]) r += self.ale.act(self.actions[act])
newlives = self.ale.lives() newlives = self.ale.lives()
if self.ale.game_over() or newlives < oldlives: if self.ale.game_over() or \
(self.live_lost_as_eoe and newlives < oldlives):
break break
self.current_episode_score.feed(r) self.current_episode_score.feed(r)
...@@ -126,7 +134,8 @@ class AtariPlayer(RLEnvironment): ...@@ -126,7 +134,8 @@ class AtariPlayer(RLEnvironment):
if isOver: if isOver:
self.stats['score'].append(self.current_episode_score.sum) self.stats['score'].append(self.current_episode_score.sum)
self._reset() self._reset()
isOver = isOver or newlives < oldlives if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
return (r, isOver) return (r, isOver)
def get_stat(self): def get_stat(self):
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import multiprocessing, threading import multiprocessing, threading
import tensorflow as tf import tensorflow as tf
import six
from six.moves import queue, range from six.moves import queue, range
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment