Project: seminar-breakout (Shashank Suhas)

Commit d3167ba3, authored Nov 08, 2016 by Yuxin Wu
Commit message: add dueling
Parent: 4f3d4e27

Showing 2 changed files with 32 additions and 26 deletions (+32 −26)
examples/Atari2600/DQN.py      +29 −25
examples/Atari2600/README.md    +3 −1
examples/Atari2600/DQN.py

```diff
@@ -23,13 +23,13 @@ import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
 
+METHOD = ['DQN', 'Double', 'Dueling'][1]
+
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4
-HEIGHT_RANGE = (None, None)
-#HEIGHT_RANGE = (36, 204) # for breakout
-#HEIGHT_RANGE = (28, -8) # for pong
 CHANNEL = FRAME_HISTORY
 IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
```
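The new METHOD flag is the only switch the rest of the file consults (`METHOD != 'Double'` when building the target, `METHOD != 'Dueling'` when building the head). A minimal, hypothetical edit (not part of this commit) to train the dueling variant would just change the index:

```python
# Hypothetical one-line edit to DQN.py, not part of this commit:
# index 0 = vanilla DQN, 1 = Double-DQN (the committed default), 2 = Dueling-DQN
METHOD = ['DQN', 'Double', 'Dueling'][2]
```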
```diff
@@ -50,9 +50,8 @@ NUM_ACTIONS = None
 ROM_FILE = None
 
 def get_player(viz=False, train=False):
-    pl = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
+    pl = AtariPlayer(ROM_FILE,
                      frame_skip=ACTION_REPEAT,
                      image_shape=IMAGE_SIZE[::-1], viz=viz,
                      live_lost_as_eoe=train)
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
```
```diff
@@ -76,7 +75,7 @@ class Model(ModelDesc):
         """ image: [0,255]"""
         image = image / 255.0
         with argscope(Conv2D, nl=PReLU.f, use_bias=True):
-            return (LinearWrap(image)
+            l = (LinearWrap(image)
                 .Conv2D('conv0', out_channel=32, kernel_shape=5)
                 .MaxPooling('pool0', 2)
                 .Conv2D('conv1', out_channel=32, kernel_shape=5)
```
```diff
@@ -90,8 +89,14 @@ class Model(ModelDesc):
                 #.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
                 #.Conv2D('conv2', out_channel=64, kernel_shape=3)
-                .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
-                .FullyConnected('fct', NUM_ACTIONS, nl=tf.identity)())
+                .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))())
+        if METHOD != 'Dueling':
+            Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
+        else:
+            V = FullyConnected('fctV', l, 1, nl=tf.identity)
+            As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
+            Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
+        return tf.identity(Q, name='Qvalue')
 
     def _build_graph(self, inputs):
         state, action, reward, next_state, isOver = inputs
```
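The dueling branch above splits the head into a scalar state-value stream (`fctV`) and an advantage stream (`fctA`), then recombines them with the mean-subtracted aggregation from the Dueling-DQN paper:

```latex
Q(s,a) \;=\; A(s,a) \;+\; \Bigl( V(s) \;-\; \frac{1}{|\mathcal{A}|}\sum_{a'} A(s,a') \Bigr)
```

Subtracting the mean advantage removes the degree of freedom that would otherwise let V and A shift by opposite constants without changing Q; the `tf.reduce_mean(As, 1, keep_dims=True)` term in the new code is exactly that correction.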
```diff
@@ -105,22 +110,22 @@ class Model(ModelDesc):
         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
 
-        # DQN
-        #best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
-        # Double-DQN
-        tf.get_variable_scope().reuse_variables()
-        next_predict_value = self._get_DQN_prediction(next_state)
-        self.greedy_choice = tf.argmax(next_predict_value, 1)   # N,
-        predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
-        best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
+        if METHOD != 'Double':
+            # DQN
+            best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
+        else:
+            # Double-DQN
+            tf.get_variable_scope().reuse_variables()
+            next_predict_value = self._get_DQN_prediction(next_state)
+            self.greedy_choice = tf.argmax(next_predict_value, 1)   # N,
+            predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
+            best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
 
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
 
-        cost = symbf.huber_loss(target - pred_action_value)
+        self.cost = symbf.huber_loss(target - pred_action_value, name='cost')
         summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
                                    ('fc.*/W', ['histogram', 'rms'])])   # monitor all W
-        self.cost = tf.reduce_mean(cost, name='cost')
 
     def update_target_param(self):
         vars = tf.trainable_variables()
```
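In formula form, the two bootstrap targets selected by METHOD in this hunk are the standard DQN target and the Double-DQN target, where Q⁻ is the target network built under the 'target' scope, Q is the online network, and the (1 − isOver) factor zeroes the bootstrap term at episode end:

```latex
\begin{aligned}
y_{\text{DQN}}    &= r + (1 - \text{isOver})\,\gamma\, \max_{a'} Q^{-}(s', a') \\
y_{\text{Double}} &= r + (1 - \text{isOver})\,\gamma\, Q^{-}\!\bigl(s', \arg\max_{a'} Q(s', a')\bigr)
\end{aligned}
```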
```diff
@@ -134,8 +139,7 @@ class Model(ModelDesc):
         return tf.group(*ops, name='update_target_network')
 
     def get_gradient_processor(self):
-        return [MapGradient(lambda grad: \
+        return [MapGradient(lambda grad:
                 tf.clip_by_global_norm([grad], 5)[0][0]),
                 SummaryGradient()]
 
 def get_config():
```
```diff
@@ -143,7 +147,7 @@ def get_config():
     M = Model()
     dataset_train = ExpReplay(
-        predictor_io_names=(['state'], ['fct/output']),
+        predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
```
```diff
@@ -167,7 +171,7 @@ def get_config():
             [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
         RunOp(lambda: M.update_target_param()),
         dataset_train,
-        PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 3),
+        PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
         #HumanHyperParamSetter('learning_rate', 'hyper.txt'),
         #HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
     ]),
```
```diff
@@ -197,7 +201,7 @@ if __name__ == '__main__':
         model=Model(),
         session_init=SaverRestore(args.load),
         input_var_names=['state'],
-        output_var_names=['fct/output:0'])
+        output_var_names=['Qvalue'])
     if args.task == 'play':
         play_model(cfg)
     elif args.task == 'eval':
```
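With the dueling option, the Q head is no longer a single FullyConnected layer named 'fct', so the three call sites above now reference the explicitly named 'Qvalue' tensor instead of 'fct/output'. As a minimal sketch (NumPy only, shapes assumed, not code from this repository), a consumer of that N×A Q-value output such as the replay collector typically turns one row into an ε-greedy action:

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng=np.random):
    """Pick an action from a length-A vector of Q-values (one row of the 'Qvalue' output)."""
    if rng.rand() < epsilon:
        return rng.randint(len(q_values))   # explore: uniform random action
    return int(np.argmax(q_values))         # exploit: greedy action

# Example usage with a dummy Q-value vector for 4 actions
print(epsilon_greedy(np.array([0.1, 0.5, -0.2, 0.3]), epsilon=0.1))
```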
examples/Atari2600/README.md

```diff
@@ -2,7 +2,7 @@
 [video demo](https://youtu.be/o21mddZtE5Y)
 
-Reproduce the following reinforcement learning methods:
+Reproduce the following reinforcement learning papers:
 
 + Nature-DQN in:
 [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
```
```diff
@@ -10,6 +10,8 @@ Reproduce the following reinforcement learning methods:
 + Double-DQN in:
 [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
 
++ Dueling-DQN in:
+[Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581)
+
 + A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
 used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
```