Commit df869711 authored Nov 13, 2017 by Yuxin Wu
[DQN] atari env info consistent with gym settings
parent f9f1e437
Showing 3 changed files with 7 additions and 8 deletions:
examples/DeepQNetwork/README.md     +1 -1
examples/DeepQNetwork/atari.py      +1 -5
examples/DeepQNetwork/expreplay.py  +5 -2
examples/DeepQNetwork/README.md
@@ -23,7 +23,7 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game.
 Batch-A3C implementation only took <2 hours.
-Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
+Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.

 ## How to use
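A quick sanity check (not part of the commit): the three frame counts quoted above are mutually consistent under this example's training settings, assuming the defaults in DQN.py of batch size 64, one training batch every 4 agent steps, and frame skip 4.

# Relations between the throughput numbers above; BATCH_SIZE, UPDATE_FREQ and
# FRAME_SKIP are assumed defaults from DQN.py, not values stated in this diff.
BATCH_SIZE = 64     # frames sampled from replay per training batch (assumed)
UPDATE_FREQ = 4     # agent steps taken per training batch (assumed)
FRAME_SKIP = 4      # emulator frames per agent step (assumed)

batches_per_sec = 60
print(batches_per_sec * BATCH_SIZE)                # 3840 trained frames per second
print(batches_per_sec * UPDATE_FREQ)               # 240 seen frames per second
print(batches_per_sec * UPDATE_FREQ * FRAME_SKIP)  # 960 game frames per second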
examples/DeepQNetwork/atari.py
@@ -97,8 +97,6 @@ class AtariPlayer(gym.Env):
         self.frame_skip = frame_skip
         self.nullop_start = nullop_start

-        self.current_episode_score = StatCounter()
-
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
             low=0, high=255, shape=(self.height, self.width))

@@ -131,7 +129,6 @@ class AtariPlayer(gym.Env):
         return ret.astype('uint8')  # to save some memory

     def _restart_episode(self):
-        self.current_episode_score.reset()
         with _ALE_LOCK:
             self.ale.reset_game()

@@ -160,12 +157,11 @@ class AtariPlayer(gym.Env):
                     (self.live_lost_as_eoe and newlives < oldlives):
                 break

-        self.current_episode_score.feed(r)
         trueIsOver = isOver = self.ale.game_over()
         if self.live_lost_as_eoe:
             isOver = isOver or newlives < oldlives

-        info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
+        info = {'ale.lives': newlives}
         return self._current_state(), r, isOver, info
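With the 'score' and 'gameOver' keys gone, AtariPlayer.step() now returns the same kind of info dict as gym's own Atari environments, reporting the remaining lives under 'ale.lives'. A minimal usage sketch, assuming a local ROM path and the constructor keywords shown below (both assumptions, not taken from this diff):

# Hedged sketch: telling a lost life apart from a full game over via the
# gym-style info dict. The ROM path and keyword arguments are assumptions.
import random
from atari import AtariPlayer   # examples/DeepQNetwork/atari.py

env = AtariPlayer('breakout.bin', live_lost_as_eoe=True)
ob = env.reset()
game_score = 0.0
for _ in range(10000):
    ob, r, isOver, info = env.step(random.randrange(env.action_space.n))
    game_score += r
    if isOver:                      # an episode ends on every lost life
        if info['ale.lives'] == 0:  # the whole game is over only when no lives remain
            print('game score:', game_score)
            game_score = 0.0
        ob = env.reset()

Whether a lost life ends an episode is still controlled by live_lost_as_eoe; only the cumulative-score bookkeeping has moved out of the environment (see the expreplay.py change below).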
examples/DeepQNetwork/expreplay.py
@@ -155,6 +155,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem = ReplayMemory(memory_size, state_shape, history_len)
         self._current_ob = self.player.reset()
         self._player_scores = StatCounter()
+        self._current_game_score = StatCounter()

     def get_simulator_thread(self):
         # spawn a separate thread to run policy

@@ -202,9 +203,11 @@ class ExpReplay(DataFlow, Callback):
             q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
         self._current_ob, reward, isOver, info = self.player.step(act)
+        self._current_game_score.feed(reward)
         if isOver:
-            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
-                self._player_scores.feed(info['score'])
+            if info['ale.lives'] == 0:  # only record score when a whole game is over (not when an episode is over)
+                self._player_scores.feed(self._current_game_score.sum)
+                self._current_game_score.reset()
             self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))
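Because the environment no longer reports a running 'score', ExpReplay keeps the per-game score itself: every reward is fed into _current_game_score, and its sum is recorded and reset only when info['ale.lives'] == 0, i.e. when a whole game (all lives) has ended rather than a single episode. A stand-alone sketch of that pattern; Accumulator and on_transition below are illustrative stand-ins for StatCounter and the real simulator loop, not names from the codebase:

# Illustrative sketch of the consumer-side score tracking introduced above.
class Accumulator:
    def __init__(self):
        self.sum = 0.0
    def feed(self, v):
        self.sum += v
    def reset(self):
        self.sum = 0.0

current_game_score = Accumulator()   # per-game score now lives in the consumer
player_scores = []                   # one entry per finished game

def on_transition(player, reward, isOver, info):
    current_game_score.feed(reward)            # every reward counts toward the game score
    if isOver:
        if info['ale.lives'] == 0:             # all lives gone: the game is really over
            player_scores.append(current_game_score.sum)
            current_game_score.reset()
        player.reset()                         # a single episode (one life) ended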