Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ba435f10
Commit
ba435f10
authored
Jun 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update DQN hyperparam
parent
4071cbec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
7 deletions
+10
-7
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+10
-7
No files found.
examples/Atari2600/DQN.py
View file @
ba435f10
...
@@ -20,11 +20,12 @@ from tensorpack.RL import *
...
@@ -20,11 +20,12 @@ from tensorpack.RL import *
import
common
import
common
from
common
import
play_model
,
Evaluator
,
eval_model_multithread
from
common
import
play_model
,
Evaluator
,
eval_model_multithread
BATCH_SIZE
=
32
BATCH_SIZE
=
64
IMAGE_SIZE
=
(
84
,
84
)
IMAGE_SIZE
=
(
84
,
84
)
FRAME_HISTORY
=
4
FRAME_HISTORY
=
4
ACTION_REPEAT
=
4
ACTION_REPEAT
=
4
HEIGHT_RANGE
=
(
36
,
204
)
# for breakout
HEIGHT_RANGE
=
(
None
,
None
)
#HEIGHT_RANGE = (36, 204) # for breakout
#HEIGHT_RANGE = (28, -8) # for pong
#HEIGHT_RANGE = (28, -8) # for pong
CHANNEL
=
FRAME_HISTORY
CHANNEL
=
FRAME_HISTORY
...
@@ -32,7 +33,7 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
...
@@ -32,7 +33,7 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
GAMMA
=
0.99
GAMMA
=
0.99
INIT_EXPLORATION
=
1
INIT_EXPLORATION
=
1
EXPLORATION_EPOCH_ANNEAL
=
0.008
EXPLORATION_EPOCH_ANNEAL
=
0.01
END_EXPLORATION
=
0.1
END_EXPLORATION
=
0.1
MEMORY_SIZE
=
1e6
MEMORY_SIZE
=
1e6
...
@@ -133,7 +134,7 @@ class Model(ModelDesc):
...
@@ -133,7 +134,7 @@ class Model(ModelDesc):
SummaryGradient
()]
SummaryGradient
()]
def
predictor
(
self
,
state
):
def
predictor
(
self
,
state
):
# TODO change to a multitower predictor for speedup
# TODO use multitower predictor to speed up training
return
self
.
predict_value
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})[
0
]
return
self
.
predict_value
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})[
0
]
def
get_config
():
def
get_config
():
...
@@ -155,7 +156,7 @@ def get_config():
...
@@ -155,7 +156,7 @@ def get_config():
reward_clip
=
(
-
1
,
1
),
reward_clip
=
(
-
1
,
1
),
history_len
=
FRAME_HISTORY
)
history_len
=
FRAME_HISTORY
)
lr
=
tf
.
Variable
(
0.0004
,
trainable
=
False
,
name
=
'learning_rate'
)
lr
=
tf
.
Variable
(
0.001
,
trainable
=
False
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
return
TrainConfig
(
return
TrainConfig
(
...
@@ -164,11 +165,13 @@ def get_config():
...
@@ -164,11 +165,13 @@ def get_config():
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
150
,
4e-4
),
(
250
,
1e-4
),
(
350
,
5e-5
)]),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
RunOp
(
lambda
:
M
.
update_target_param
()),
RunOp
(
lambda
:
M
.
update_target_param
()),
dataset_train
,
dataset_train
,
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'fct/output'
]),
2
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'fct/output'
]),
3
),
]),
]),
# save memory for multiprocess evaluator
# save memory for multiprocess evaluator
session_config
=
get_default_sess_config
(
0.6
),
session_config
=
get_default_sess_config
(
0.6
),
...
@@ -205,6 +208,6 @@ if __name__ == '__main__':
...
@@ -205,6 +208,6 @@ if __name__ == '__main__':
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
SimpleTrainer
(
config
)
.
train
()
SimpleTrainer
(
config
)
.
train
()
# TODO test if queue trainer works
#QueueInputTrainer(config).train()
#QueueInputTrainer(config).train()
# TODO test if QueueInput affects learning
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment