Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b61a2d89
Commit
b61a2d89
authored
May 30, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve dqn
parent
77755875
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
22 deletions
+10
-22
examples/DeepQNetwork/DQN.py
examples/DeepQNetwork/DQN.py
+8
-18
examples/DeepQNetwork/expreplay.py
examples/DeepQNetwork/expreplay.py
+2
-4
No files found.
examples/DeepQNetwork/DQN.py
View file @
b61a2d89
...
@@ -34,10 +34,6 @@ ACTION_REPEAT = 4
...
@@ -34,10 +34,6 @@ ACTION_REPEAT = 4
GAMMA
=
0.99
GAMMA
=
0.99
INIT_EXPLORATION
=
1
EXPLORATION_EPOCH_ANNEAL
=
0.01
END_EXPLORATION
=
0.1
MEMORY_SIZE
=
1e6
MEMORY_SIZE
=
1e6
# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
INIT_MEMORY_SIZE
=
5e4
INIT_MEMORY_SIZE
=
5e4
...
@@ -73,18 +69,10 @@ class Model(DQNModel):
...
@@ -73,18 +69,10 @@ class Model(DQNModel):
with
argscope
(
Conv2D
,
nl
=
PReLU
.
symbolic_function
,
use_bias
=
True
),
\
with
argscope
(
Conv2D
,
nl
=
PReLU
.
symbolic_function
,
use_bias
=
True
),
\
argscope
(
LeakyReLU
,
alpha
=
0.01
):
argscope
(
LeakyReLU
,
alpha
=
0.01
):
l
=
(
LinearWrap
(
image
)
l
=
(
LinearWrap
(
image
)
.
Conv2D
(
'conv0'
,
out_channel
=
32
,
kernel_shape
=
5
)
.
MaxPooling
(
'pool0'
,
2
)
.
Conv2D
(
'conv1'
,
out_channel
=
32
,
kernel_shape
=
5
)
.
MaxPooling
(
'pool1'
,
2
)
.
Conv2D
(
'conv2'
,
out_channel
=
64
,
kernel_shape
=
4
)
.
MaxPooling
(
'pool2'
,
2
)
.
Conv2D
(
'conv3'
,
out_channel
=
64
,
kernel_shape
=
3
)
# the original arch is 2x faster
# the original arch is 2x faster
#
.Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4)
.
Conv2D
(
'conv0'
,
out_channel
=
32
,
kernel_shape
=
8
,
stride
=
4
)
#
.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
.
Conv2D
(
'conv1'
,
out_channel
=
64
,
kernel_shape
=
4
,
stride
=
2
)
#
.Conv2D('conv2', out_channel=64, kernel_shape=3)
.
Conv2D
(
'conv2'
,
out_channel
=
64
,
kernel_shape
=
3
)
.
FullyConnected
(
'fc0'
,
512
,
nl
=
LeakyReLU
)())
.
FullyConnected
(
'fc0'
,
512
,
nl
=
LeakyReLU
)())
if
self
.
method
!=
'Dueling'
:
if
self
.
method
!=
'Dueling'
:
...
@@ -108,9 +96,7 @@ def get_config():
...
@@ -108,9 +96,7 @@ def get_config():
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
memory_size
=
MEMORY_SIZE
,
memory_size
=
MEMORY_SIZE
,
init_memory_size
=
INIT_MEMORY_SIZE
,
init_memory_size
=
INIT_MEMORY_SIZE
,
exploration
=
INIT_EXPLORATION
,
init_exploration
=
1.0
,
end_exploration
=
END_EXPLORATION
,
exploration_epoch_anneal
=
EXPLORATION_EPOCH_ANNEAL
,
update_frequency
=
4
,
update_frequency
=
4
,
history_len
=
FRAME_HISTORY
history_len
=
FRAME_HISTORY
)
)
...
@@ -121,6 +107,10 @@ def get_config():
...
@@ -121,6 +107,10 @@ def get_config():
ModelSaver
(),
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
150
,
4e-4
),
(
250
,
1e-4
),
(
350
,
5e-5
)]),
[(
150
,
4e-4
),
(
250
,
1e-4
),
(
350
,
5e-5
)]),
ScheduledHyperParamSetter
(
ObjAttrParam
(
expreplay
,
'exploration'
),
[(
0
,
1
),
(
100
,
0.1
),
(
200
,
0.01
)],
interp
=
'linear'
),
RunOp
(
DQNModel
.
update_target_param
),
RunOp
(
DQNModel
.
update_target_param
),
expreplay
,
expreplay
,
PeriodicTrigger
(
Evaluator
(
PeriodicTrigger
(
Evaluator
(
...
...
examples/DeepQNetwork/expreplay.py
View file @
b61a2d89
...
@@ -123,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
...
@@ -123,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
state_shape
,
state_shape
,
batch_size
,
batch_size
,
memory_size
,
init_memory_size
,
memory_size
,
init_memory_size
,
exploration
,
end_exploration
,
exploration_epoch_anneal
,
init_exploration
,
update_frequency
,
history_len
):
update_frequency
,
history_len
):
"""
"""
Args:
Args:
...
@@ -140,6 +140,7 @@ class ExpReplay(DataFlow, Callback):
...
@@ -140,6 +140,7 @@ class ExpReplay(DataFlow, Callback):
for
k
,
v
in
locals
()
.
items
():
for
k
,
v
in
locals
()
.
items
():
if
k
!=
'self'
:
if
k
!=
'self'
:
setattr
(
self
,
k
,
v
)
setattr
(
self
,
k
,
v
)
self
.
exploration
=
init_exploration
self
.
num_actions
=
player
.
get_action_space
()
.
num_actions
()
self
.
num_actions
=
player
.
get_action_space
()
.
num_actions
()
logger
.
info
(
"Number of Legal actions: {}"
.
format
(
self
.
num_actions
))
logger
.
info
(
"Number of Legal actions: {}"
.
format
(
self
.
num_actions
))
...
@@ -245,9 +246,6 @@ class ExpReplay(DataFlow, Callback):
...
@@ -245,9 +246,6 @@ class ExpReplay(DataFlow, Callback):
self
.
_simulator_th
.
start
()
self
.
_simulator_th
.
start
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
exploration
>
self
.
end_exploration
:
self
.
exploration
-=
self
.
exploration_epoch_anneal
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
exploration
))
# log player statistics
# log player statistics
stats
=
self
.
player
.
stats
stats
=
self
.
player
.
stats
for
k
,
v
in
six
.
iteritems
(
stats
):
for
k
,
v
in
six
.
iteritems
(
stats
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment