Commit df869711 authored by Yuxin Wu

[DQN] atari env info consistent with gym settings

parent f9f1e437
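
For reference, "gym settings" here means the convention of gym's Atari environments, which report the remaining lives through the `info` dict returned by `step()`. A minimal sketch of that convention (assuming an Atari env such as `Breakout-v0` and the 4-tuple `step()` API of that gym era; not part of this commit):

```python
import gym

env = gym.make('Breakout-v0')
obs = env.reset()
# gym's Atari envs put the remaining lives into info['ale.lives'];
# this commit makes AtariPlayer return the same kind of info dict.
obs, reward, done, info = env.step(env.action_space.sample())
print(info['ale.lives'])  # 0 means all lives are gone, i.e. the whole game is over
```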
@@ -23,7 +23,7 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game.
 Batch-A3C implementation only took <2 hours.
-Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
+Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
 ## How to use
@@ -97,8 +97,6 @@ class AtariPlayer(gym.Env):
         self.frame_skip = frame_skip
         self.nullop_start = nullop_start
-        self.current_episode_score = StatCounter()
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
             low=0, high=255, shape=(self.height, self.width))
@@ -131,7 +129,6 @@ class AtariPlayer(gym.Env):
         return ret.astype('uint8')  # to save some memory

     def _restart_episode(self):
-        self.current_episode_score.reset()
         with _ALE_LOCK:
             self.ale.reset_game()
@@ -160,12 +157,11 @@ class AtariPlayer(gym.Env):
                     (self.live_lost_as_eoe and newlives < oldlives):
                 break
-        self.current_episode_score.feed(r)
         trueIsOver = isOver = self.ale.game_over()
         if self.live_lost_as_eoe:
             isOver = isOver or newlives < oldlives
-        info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
+        info = {'ale.lives': newlives}
         return self._current_state(), r, isOver, info
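
With the hunk above, a caller of `AtariPlayer.step()` can tell an end-of-episode (e.g. a lost life when `live_lost_as_eoe` is set) apart from a true game over by checking the new `info` dict. A rough sketch of such a consumer loop (hypothetical function and variable names, not code from this commit):

```python
import numpy as np

def play_one_full_game(player, predictor):
    """Play until all lives are gone and return the total raw score.

    `player` is assumed to behave like AtariPlayer above; `predictor`
    is any callable mapping a state to Q-values (hypothetical).
    """
    total_score = 0.0
    state = player.reset()
    while True:
        act = np.argmax(predictor(state))
        state, reward, is_over, info = player.step(act)
        total_score += reward
        if is_over:
            if info['ale.lives'] == 0:  # all lives gone: the whole game is over
                return total_score
            state = player.reset()      # only an episode ended (a life was lost)
```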
@@ -155,6 +155,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem = ReplayMemory(memory_size, state_shape, history_len)
         self._current_ob = self.player.reset()
         self._player_scores = StatCounter()
+        self._current_game_score = StatCounter()

     def get_simulator_thread(self):
         # spawn a separate thread to run policy
@@ -202,9 +203,11 @@ class ExpReplay(DataFlow, Callback):
             q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
         self._current_ob, reward, isOver, info = self.player.step(act)
+        self._current_game_score.feed(reward)
         if isOver:
-            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
-                self._player_scores.feed(info['score'])
+            if info['ale.lives'] == 0:  # only record score when a whole game is over (not when an episode is over)
+                self._player_scores.feed(self._current_game_score.sum)
+                self._current_game_score.reset()
             self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))
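
The per-game bookkeeping above leans on tensorpack's `StatCounter`. For readers following along outside tensorpack, a minimal stand-in covering only the interface used in this hunk (`feed`, `reset`, `sum`; an assumption for illustration, not tensorpack's actual implementation) could look like:

```python
class StatCounter(object):
    """Minimal stand-in for tensorpack's StatCounter (assumed interface).

    Only the pieces used above are provided: feed() to accumulate a value,
    reset() to clear, and the .sum property to read the running total.
    """
    def __init__(self):
        self._values = []

    def feed(self, value):
        self._values.append(value)

    def reset(self):
        self._values = []

    @property
    def sum(self):
        return sum(self._values)
```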