Project: Shashank Suhas / seminar-breakout

Commit fefdcfb1, authored Jul 16, 2016 by Yuxin Wu
Parent: bbc17cb1

    use get_predict_func for DQN
Showing 2 changed files with 10 additions and 7 deletions:

    examples/Atari2600/DQN.py     +1 / -4
    tensorpack/RL/expreplay.py    +9 / -3
examples/Atari2600/DQN.py  (view file @ fefdcfb1)

@@ -134,15 +134,12 @@ class Model(ModelDesc):
             tf.clip_by_global_norm([grad], 5)[0][0]),
             SummaryGradient()]
 
-    def predictor(self, state):
-        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
-
 def get_config():
     logger.auto_set_dir()
     M = Model()
 
     dataset_train = ExpReplay(
-        predictor=M.predictor,
+        predictor_io_names=(['state'], ['fct/output']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
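The removed method evaluated predict_value directly with a per-state feed dict; the new predictor_io_names argument instead names the input and output tensors ('state' and 'fct/output') and leaves predictor construction to ExpReplay. Read together with the expreplay.py changes below, trainer.get_predict_func(input_names, output_names) appears to return a callable taking one list per input tensor (each holding a batch) and returning one array per output tensor. A minimal sketch of that contract, where everything except the tensor names from the diff is an assumption:

    # Sketch only: `trainer` and the exact return convention of
    # get_predict_func are assumptions inferred from the call sites below.
    predict = trainer.get_predict_func(['state'], ['fct/output'])

    # One forward pass over a batch holding a single stacked state:
    # outer list = one entry per input tensor, inner list = the batch.
    q_values = predict([[stacked_state]])[0][0]  # one Q-value per action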
tensorpack/RL/expreplay.py  (view file @ fefdcfb1)

@@ -29,7 +29,7 @@ class ExpReplay(DataFlow, Callback):
     This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
     """
     def __init__(self,
-            predictor,
+            predictor_io_names,
             player,
             batch_size=32,
             memory_size=1e6,
@@ -64,6 +64,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
+        self._predictor_io_names = predictor_io_names
 
     def _init_memory(self):
         logger.info("Populating replay memory...")
@@ -90,6 +91,7 @@ class ExpReplay(DataFlow, Callback):
             act = self.rng.choice(range(self.num_actions))
         else:
             # build a history state
+            # XXX assume a state can be representated by one tensor
             ss = [old_s]
 
             isOver = False
@@ -103,7 +105,9 @@ class ExpReplay(DataFlow, Callback):
                 ss.append(hist_exp.state)
             ss.reverse()
             ss = np.concatenate(ss, axis=2)
-            act = np.argmax(self.predictor(ss))
+            # XXX assume batched network
+            q_values = self.predictor([[ss]])[0][0]
+            act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
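The old call handed a single state to the bound method; the new code assumes a batched network, wrapping the stacked history state ss into a batch of one and then unwrapping twice: the first [0] selects the only output tensor ('fct/output'), the second selects the only batch entry. A runnable mock of that calling convention, with all shapes hypothetical:

    import numpy as np

    NUM_ACTIONS = 4                     # hypothetical action count

    def mock_predictor(inputs):
        # inputs: one list per input tensor; inputs[0] is the state batch
        states = inputs[0]
        # pretend network output: one row of Q-values per state in the batch
        return [np.zeros((len(states), NUM_ACTIONS))]

    ss = np.zeros((84, 84, 4))               # hypothetical stacked state
    q_values = mock_predictor([[ss]])[0][0]  # -> shape (NUM_ACTIONS,)
    act = int(np.argmax(q_values))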
@@ -171,6 +175,9 @@ class ExpReplay(DataFlow, Callback):
         isOver = np.array([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]
 
+    def _setup_graph(self):
+        self.predictor = self.trainer.get_predict_func(*self._predictor_io_names)
+
     # Callback-related:
     def _before_train(self):
         # spawn a separate thread to run policy, can speed up 1.3x
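_setup_graph is the Callback hook that fires once the training graph has been built, which is the earliest point at which the trainer can hand out a predict function; that is why __init__ (the hunk at line 64) only stores the tensor names and the predictor itself is created here. A self-contained sketch of that deferred construction, where only _setup_graph, _predictor_io_names, and get_predict_func come from the diff and the rest is hypothetical:

    import numpy as np

    class MockTrainer:
        def get_predict_func(self, input_names, output_names):
            # stand-in for compiling a predict function from the built graph
            return lambda inputs: [np.zeros((len(inputs[0]), 4))]

    class ReplaySketch:
        def __init__(self, predictor_io_names):
            # the graph does not exist yet: remember only the tensor names
            self._predictor_io_names = predictor_io_names

        def _setup_graph(self):
            # called after graph construction; now a predictor can be built
            self.predictor = self.trainer.get_predict_func(
                *self._predictor_io_names)

    r = ReplaySketch((['state'], ['fct/output']))
    r.trainer = MockTrainer()  # the framework is assumed to set this
    r._setup_graph()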
@@ -204,7 +211,6 @@ if __name__ == '__main__':
     from .atari import AtariPlayer
     import sys
     predictor = lambda x: np.array([1, 1, 1, 1])
-    predictor.initialized = False
     player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
     E = ExpReplay(predictor,
                   player=player,
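Note that the smoke test under __main__ still builds its dummy predictor with the old one-state signature (lambda x: np.array([1, 1, 1, 1])), while the epsilon-greedy code above now calls self.predictor([[ss]])[0][0]. A stand-in matching the new batched convention would look more like this hypothetical stub:

    # one (batch, num_actions) array per output tensor, so that
    # predictor([[ss]])[0][0] yields the row of Q-values
    predictor = lambda inputs: [np.array([[1, 1, 1, 1]])]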