Shashank Suhas / seminar-breakout / Commits / 49a21a29

Commit 49a21a29, authored Jun 05, 2016 by Yuxin Wu
DQN release ready
parent 33353f33
Showing 3 changed files with 56 additions and 28 deletions:
- examples/Atari2600/DQN.py (+23 −18)
- examples/Atari2600/README.md (+14 −0)
- tensorpack/RL/atari.py (+19 −10)
examples/Atari2600/DQN.py
```diff
@@ -28,21 +28,15 @@ from tensorpack.callbacks import *
 from tensorpack.RL import *
 
-"""
-Implement DQN in:
-Human-level Control Through Deep Reinforcement Learning
-for atari games. Use the variants in:
-Deep Reinforcement Learning with Double Q-learning.
-"""
 
 BATCH_SIZE = 32
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
-ACTION_REPEAT = 3
+ACTION_REPEAT = 4
 HEIGHT_RANGE = (36, 204)    # for breakout
+#HEIGHT_RANGE = (28, -8)    # for pong
 CHANNEL = FRAME_HISTORY
 IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
-#HEIGHT_RANGE = (28, -8)    # for pong
 GAMMA = 0.99
 INIT_EXPLORATION = 1
```
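For context on the variant named in the (relocated) docstring, *Deep Reinforcement Learning with Double Q-learning*: the Double-DQN target picks the next action with the online network but evaluates it with the target network. The sketch below is illustrative only and is not code from this commit; the array names are invented for the example.

```python
# Illustrative NumPy sketch of the Double-DQN target; not repository code.
import numpy as np

def double_dqn_target(reward, done, q_online_next, q_target_next, gamma=0.99):
    """reward, done: shape (batch,); q_*_next: shape (batch, num_actions)."""
    # The online network chooses the greedy next action ...
    best_action = np.argmax(q_online_next, axis=1)
    # ... and the target network evaluates that action.
    next_value = q_target_next[np.arange(len(best_action)), best_action]
    return reward + gamma * (1.0 - done.astype(np.float32)) * next_value
```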
```diff
@@ -52,7 +46,7 @@ END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
 INIT_MEMORY_SIZE = 50000
 STEP_PER_EPOCH = 10000
-EVAL_EPISODE = 100
+EVAL_EPISODE = 50
 NUM_ACTIONS = None
 ROM_FILE = None
```
```diff
@@ -63,10 +57,10 @@ def get_player(viz=False, train=False):
             live_lost_as_eoe=train)
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_num_actions()
     if not train:
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 20000)
     return pl
 
 class Model(ModelDesc):
```
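The `get_player` hunk above layers several wrappers around the raw `AtariPlayer` (frame history, anti-stuck behavior, episode-length limit). As a rough illustration of that wrapping pattern only, and assuming a player exposes an `action(act) -> (reward, is_over)` method (an assumption, not something this diff shows), a length-limit wrapper could look like this hypothetical sketch:

```python
# Hypothetical sketch of a length-limiting wrapper; this is NOT tensorpack's
# actual LimitLengthPlayer implementation.
class LengthLimitWrapper(object):
    def __init__(self, player, limit):
        self.player = player
        self.limit = limit
        self.cnt = 0

    def action(self, act):
        reward, is_over = self.player.action(act)
        self.cnt += 1
        if self.cnt >= self.limit:
            is_over = True      # force end-of-episode after `limit` steps
        if is_over:
            self.cnt = 0        # reset the counter for the next episode
        return reward, is_over
```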
```diff
@@ -81,7 +75,7 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image, is_training):
         """ image: [0,255]"""
         image = image / 255.0
-        with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
+        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
             l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
             l = MaxPooling('pool0', l, 2)
             l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
```
```diff
@@ -158,7 +152,11 @@ def play_one_episode(player, func, verbose=False):
     return np.mean(player.play_one_episode(f))
 
 def play_model(model_path):
-    player = get_player(0.013)
+    import uuid
+    dirname = 'record' + str(uuid.uuid1())[:6]
+    print dirname
+    os.mkdir(dirname)
+    player = get_player(viz=dirname)
     cfg = PredictConfig(
             model=Model(),
             input_data_mapping=[0],
```
```diff
@@ -168,8 +166,9 @@ def play_model(model_path):
     while True:
         score = play_one_episode(player, predfunc)
         print("Total:", score)
+        break
 
-def eval_with_funcs(predict_funcs):
+def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
     class Worker(StoppableThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()
```
```diff
@@ -181,7 +180,7 @@ def eval_with_funcs(predict_funcs):
                 score = play_one_episode(player, self.func)
                 self.queue_put_stoppable(self.q, score)
 
-    q = queue.Queue(maxsize=3)
+    q = queue.Queue(maxsize=2)
     threads = [Worker(f, q) for f in predict_funcs]
     for k in threads:
```
```diff
@@ -189,10 +188,11 @@ def eval_with_funcs(predict_funcs):
         time.sleep(0.1)  # avoid simulator bugs
     stat = StatCounter()
     try:
-        for _ in tqdm(range(EVAL_EPISODE)):
+        for _ in tqdm(range(nr_eval)):
             r = q.get()
             stat.feed(r)
     finally:
+        logger.info("Waiting for all the workers to finish the last run...")
         for k in threads: k.stop()
         for k in threads: k.join()
         return (stat.average, stat.max)
```
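The `eval_with_funcs` changes keep the worker-threads-plus-bounded-queue structure and make the episode count a parameter. A simplified, self-contained sketch of that producer/consumer pattern follows; it uses plain `threading`/`queue` instead of tensorpack's `StoppableThread`/`StatCounter`, and the `run_episode_fns` argument is invented for the example.

```python
# Simplified sketch of threaded evaluation with a bounded queue; the real
# code uses tensorpack's StoppableThread and StatCounter helpers.
import queue
import threading

def eval_in_parallel(run_episode_fns, nr_eval):
    q = queue.Queue(maxsize=2)      # small buffer applies back-pressure

    def worker(run_episode):
        while True:
            q.put(run_episode())    # each call plays one episode and returns a score

    for fn in run_episode_fns:
        threading.Thread(target=worker, args=(fn,), daemon=True).start()

    scores = [q.get() for _ in range(nr_eval)]
    return sum(scores) / len(scores), max(scores)
```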
```diff
@@ -214,9 +214,14 @@ class Evaluator(Callback):
         NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
         self.pred_funcs = [self.trainer.get_predict_func(
             ['state'], ['fct/output'])] * NR_PROC
+        self.eval_episode = EVAL_EPISODE
 
     def _trigger_epoch(self):
-        mean, max = eval_with_funcs(self.pred_funcs)
+        t = time.time()
+        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
+        t = time.time() - t
+        if t > 8 * 60:  # eval takes too long
+            self.eval_episode = int(self.eval_episode * 0.89)
         self.trainer.write_scalar_summary('mean_score', mean)
         self.trainer.write_scalar_summary('max_score', max)
```
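The `Evaluator` change introduces a simple adaptive budget: evaluation time is measured each epoch, and if it exceeds eight minutes the number of evaluation episodes is shrunk by about 11% for the next epoch. A standalone sketch of that policy, with function and argument names invented for illustration:

```python
# Sketch of the adaptive evaluation budget shown above; illustrative only.
import time

def evaluate_with_budget(eval_fn, eval_episode, time_limit=8 * 60):
    start = time.time()
    mean_score, max_score = eval_fn(nr_eval=eval_episode)
    if time.time() - start > time_limit:          # eval took too long
        eval_episode = int(eval_episode * 0.89)   # run fewer episodes next time
    return mean_score, max_score, eval_episode
```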
```diff
@@ -240,7 +245,7 @@ def get_config():
         reward_clip=(-1, 1),
         history_len=FRAME_HISTORY)
 
-    lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
+    lr = tf.Variable(0.0004, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
 
     return TrainConfig(
```
examples/Atari2600/README.md (new file, 0 → 100644)
Implement DQN in:
**Human-level Control Through Deep Reinforcement Learning**
and Double-DQN in:
**Deep Reinforcement Learning with Double Q-learning**
To run:
```
./DQN.py --rom breakout.rom --gpu 0
```
A demo trained with Double-DQN is available at [youtube](https://youtu.be/o21mddZtE5Y).
tensorpack/RL/atari.py
```diff
@@ -4,10 +4,10 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import numpy as np
-import time
-import os
+import time, os
 import cv2
 from collections import deque
+import six
 from six.moves import range
 from ..utils import get_rng, logger, memoized
 from ..utils.stat import StatCounter
```
```diff
@@ -37,7 +37,10 @@ class AtariPlayer(RLEnvironment):
     :param frame_skip: skip every k frames and repeat the action
     :param image_shape: (w, h)
     :param height_range: (h1, h2) to cut
-    :param viz: the delay. visualize the game while running. 0 to disable
+    :param viz: visualization to be done.
+        Set to 0 to disable.
+        Set to a positive number to be the delay between frames to show.
+        Set to a string to be a directory to store frames.
     :param nullop_start: start with random number of null ops
     :param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
     """
```
```diff
@@ -57,18 +60,24 @@ class AtariPlayer(RLEnvironment):
         self.ale.setBool('color_averaging', False)
         # manual.pdf suggests otherwise. may need to check
         self.ale.setFloat('repeat_action_probability', 0.0)
 
-        self.ale.loadROM(rom_file)
-        self.width, self.height = self.ale.getScreenDims()
-        self.actions = self.ale.getMinimalActionSet()
+        # viz setup
+        if isinstance(viz, six.string_types):
+            assert os.path.isdir(viz), viz
+            self.ale.setString('record_screen_dir', viz)
+            viz = 0
         if isinstance(viz, int):
             viz = float(viz)
         self.viz = viz
-        self.romname = os.path.basename(rom_file)
         if self.viz and isinstance(self.viz, float):
+            self.windowname = os.path.basename(rom_file)
             cv2.startWindowThread()
-            cv2.namedWindow(self.romname)
+            cv2.namedWindow(self.windowname)
 
+        self.ale.loadROM(rom_file)
+        self.width, self.height = self.ale.getScreenDims()
+        self.actions = self.ale.getMinimalActionSet()
         self.live_lost_as_eoe = live_lost_as_eoe
         self.frame_skip = frame_skip
```
```diff
@@ -95,7 +104,7 @@ class AtariPlayer(RLEnvironment):
             ret = np.maximum(ret, self.last_raw_screen)
         if self.viz:
             if isinstance(self.viz, float):
-                cv2.imshow(self.romname, ret)
+                cv2.imshow(self.windowname, ret)
                 time.sleep(self.viz)
         ret = ret[self.height_range[0]:self.height_range[1],:]
         # 0.299,0.587.0.114. same as rgb2y in torch/image
```
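The trailing comment in that hunk refers to the luminance weights used for grayscale conversion (0.299, 0.587, 0.114, as in torch/image's `rgb2y`). A standalone NumPy sketch of that conversion, for reference only:

```python
# NumPy sketch of the 0.299/0.587/0.114 luminance conversion mentioned above.
import numpy as np

def rgb_to_y(frame):
    """frame: (H, W, 3) RGB array -> (H, W) float32 luminance."""
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    return frame.astype(np.float32) @ weights
```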