Shashank Suhas / seminar-breakout / Commits

Commit b6df5567, authored Feb 16, 2017 by Yuxin Wu

    move expreplay out of RL.

Parent: 63004976
Showing 2 changed files with 40 additions and 36 deletions (+40 -36):

  examples/DeepQNetwork/DQN.py        +8  -6
  examples/DeepQNetwork/expreplay.py  +32 -30

examples/DeepQNetwork/DQN.py

@@ -26,6 +26,7 @@ from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
+from expreplay import ExpReplay

 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
@@ -160,7 +161,7 @@ def get_config():
     logger.auto_set_dir()

     M = Model()
-    dataset_train = ExpReplay(
+    expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
...
@@ -174,21 +175,22 @@ def get_config():
history_len
=
FRAME_HISTORY
)
return
TrainConfig
(
dataflow
=
dataset_train
,
dataflow
=
expreplay
,
callbacks
=
[
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
150
,
4e-4
),
(
250
,
1e-4
),
(
350
,
5e-5
)]),
RunOp
(
lambda
:
M
.
update_target_param
()),
dataset_train
,
expreplay
,
StartProcOrThread
(
expreplay
.
get_simulator_thread
()),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'Qvalue'
]),
3
),
# HumanHyperParamSetter('learning_rate', 'hyper.txt'),
# HumanHyperParamSetter(ObjAttrParam(
dataset_train
, 'exploration'), 'hyper.txt'),
# HumanHyperParamSetter(ObjAttrParam(
expreplay
, 'exploration'), 'hyper.txt'),
],
# save memory for multi-thread evaluator
session_config
=
get_default_sess_config
(
0.6
),
model
=
M
,
steps_per_epoch
=
STEP_PER_EPOCH
,
# run the simulator on a separate GPU if available
predict_tower
=
[
1
]
if
get_nr_gpu
()
>
1
else
[
0
],
)
...
...
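Note on the wiring above: the same expreplay object is passed both as dataflow= and as an entry in callbacks=, which works because ExpReplay subclasses both DataFlow and Callback. The trainer pulls batches from it like any dataflow, while the callback hooks drive memory initialization and per-epoch exploration annealing. A minimal sketch of that dual-role pattern, with stand-in base classes and a toy memory rather than tensorpack's real implementations:

    # Sketch only: DataFlow/Callback here are simplified stand-ins for the
    # tensorpack classes of the same name, and ToyExpReplay is hypothetical.
    class DataFlow(object):
        def get_data(self):              # a generator of datapoints
            raise NotImplementedError

    class Callback(object):
        def _setup_graph(self): pass     # called after the graph is built
        def _before_train(self): pass    # called before the first step
        def _trigger_epoch(self): pass   # called at the end of each epoch

    class ToyExpReplay(DataFlow, Callback):
        def __init__(self):
            self.mem = []

        def _before_train(self):         # stand-in for _init_memory()
            self.mem.extend(range(100))

        def get_data(self):              # stand-in for batch sampling
            while True:
                yield [self.mem[0]]

Because sampling and bookkeeping live in one object, the replay memory needs no side channel between the dataflow and the training callbacks.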

tensorpack/RL/expreplay.py → examples/DeepQNetwork/expreplay.py
@@ -9,10 +9,10 @@ import threading
 import six
 from six.moves import queue

-from ..dataflow import DataFlow
-from ..utils import logger, get_tqdm, get_rng
-from ..utils.concurrency import LoopThread
-from ..callbacks.base import Callback
+from tensorpack.dataflow import DataFlow
+from tensorpack.utils import logger, get_tqdm, get_rng
+from tensorpack.utils.concurrency import LoopThread
+from tensorpack.callbacks.base import Callback

 __all__ = ['ExpReplay']
@@ -66,17 +66,28 @@ class ExpReplay(DataFlow, Callback):
         self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
-        self._predictor_io_names = predictor_io_names

+        # TODO just use a semaphore?
+        # a queue to receive notifications to populate memory
+        self._populate_job_queue = queue.Queue(maxsize=5)
+
+    def get_simulator_thread(self):
+        # spawn a separate thread to run policy, can speed up 1.3x
+        def populate_job_func():
+            self._populate_job_queue.get()
+            with self.trainer.sess.as_default():
+                for _ in range(self.update_frequency):
+                    self._populate_exp()
+        th = LoopThread(populate_job_func, pausable=False)
+        th.name = "SimulatorThread"
+        return th
+
     def _init_memory(self):
-        logger.info("Populating replay memory ...")
+        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

         # fill some for the history
         old_exploration = self.exploration
         self.exploration = 1
         for k in range(self.history_len):
             self._populate_exp()
         self.exploration = old_exploration

         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
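The maxsize=5 queue added above is the throttle between acting and learning: get_data() (further down in this diff) puts one token per yielded batch, and the simulator thread blocks on get() before producing update_frequency new transitions, so the actor can only run a bounded number of jobs ahead of the consumer. A self-contained sketch of that bounded handshake, using a plain daemon thread where the diff uses tensorpack's LoopThread:

    import queue
    import threading

    job_queue = queue.Queue(maxsize=5)   # producer can never be >5 jobs ahead
    experience = []

    def populate_job_func():
        while True:
            job_queue.get()              # block until training requests more
            experience.append(len(experience))   # stand-in for _populate_exp()

    th = threading.Thread(target=populate_job_func)
    th.daemon = True                     # like the simulator thread, dies with the process
    th.start()

    for _ in range(10):                  # stand-in for the get_data() loop
        job_queue.put(1)                 # one token per consumed batch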
@@ -95,7 +106,7 @@ class ExpReplay(DataFlow, Callback):
             act = self.rng.choice(range(self.num_actions))
         else:
             # build a history state
-            # XXX assume a state can be representated by one tensor
+            # assume a state can be representated by one tensor
             ss = [old_s]

             isOver = False
@@ -104,12 +115,13 @@ class ExpReplay(DataFlow, Callback):
                 if hist_exp.isOver:
                     isOver = True
                 if isOver:
+                    # fill the beginning of an episode with zeros
                     ss.append(np.zeros_like(ss[0]))
                 else:
                     ss.append(hist_exp.state)
             ss.reverse()
             ss = np.concatenate(ss, axis=2)
-            # XXX assume batched network
+            # assume batched network
             q_values = self.predictor([[ss]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
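The loop above is what turns single frames stored in the replay memory into the stacked input the network expects: the current frame plus the previous history_len-1 frames are concatenated along the channel axis, and any slot that falls across an episode boundary is zero-filled. A standalone sketch of the same logic; build_history_state and the toy Experience layout are illustrative, not the helpers defined in this file:

    import numpy as np
    from collections import namedtuple

    Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])

    def build_history_state(mem, cur_state, history_len):
        # walk backwards through memory; once an episode boundary (isOver)
        # is crossed, pad the remaining slots with zeros
        ss = [cur_state]
        crossed = False
        for k in range(1, history_len):
            hist_exp = mem[-k]
            if hist_exp.isOver:
                crossed = True
            if crossed:
                ss.append(np.zeros_like(ss[0]))
            else:
                ss.append(hist_exp.state)
        ss.reverse()                          # oldest frame first
        return np.concatenate(ss, axis=2)     # stack along the channel axis

    frame = lambda v: np.full((84, 84, 1), float(v))
    mem = [Experience(frame(0), 0, 0.0, False),
           Experience(frame(1), 0, 0.0, True),   # episode ended here
           Experience(frame(2), 0, 0.0, False)]
    s = build_history_state(mem, frame(3), history_len=4)
    assert s.shape == (84, 84, 4)
    assert (s[0, 0] == [0.0, 0.0, 2.0, 3.0]).all()   # pre-boundary frames zeroed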
@@ -118,8 +130,9 @@ class ExpReplay(DataFlow, Callback):
         self.mem.append(Experience(old_s, act, reward, isOver))

     def get_data(self):
+        # wait for memory to be initialized
         self._init_memory_flag.wait()
-        # new s is considered useless if isOver==True
         while True:
             batch_exp = [self._sample_one() for _ in range(self.batch_size)]
@@ -140,6 +153,7 @@ class ExpReplay(DataFlow, Callback):
             yield self._process_batch(batch_exp)
             self._populate_job_queue.put(1)

+    # new state is considered useless if isOver==True
     def _sample_one(self):
         """ return the transition tuple for
             [idx, idx+history_len) -> [idx+1, idx+1+history_len)
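The docstring's window notation is easiest to see with concrete indices: state is the history window starting at idx, and next_state is the same window shifted by one step, so the two share history_len-1 frames and a single run of history_len+1 consecutive frames from memory yields both. A worked toy example, with integers standing in for frames:

    # "[idx, idx+history_len) -> [idx+1, idx+1+history_len)" with
    # history_len=4 and idx=10:
    history_len = 4
    idx = 10
    run = list(range(idx, idx + history_len + 1))   # frames 10..14 from memory
    state, next_state = run[:-1], run[1:]
    assert state == [10, 11, 12, 13]
    assert next_state == [11, 12, 13, 14]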
@@ -173,29 +187,17 @@ class ExpReplay(DataFlow, Callback):
         return (state, next_state, reward, action, isOver)

     def _process_batch(self, batch_exp):
-        state = np.array([e[0] for e in batch_exp])
-        next_state = np.array([e[1] for e in batch_exp])
-        reward = np.array([e[2] for e in batch_exp])
-        action = np.array([e[3] for e in batch_exp], dtype='int8')
-        isOver = np.array([e[4] for e in batch_exp], dtype='bool')
+        state = np.asarray([e[0] for e in batch_exp])
+        next_state = np.asarray([e[1] for e in batch_exp])
+        reward = np.asarray([e[2] for e in batch_exp])
+        action = np.asarray([e[3] for e in batch_exp], dtype='int8')
+        isOver = np.asarray([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]

     def _setup_graph(self):
-        self.predictor = self.trainer.get_predict_func(*self._predictor_io_names)
+        self.predictor = self.trainer.get_predict_func(*self.predictor_io_names)

+    # Callback-related:
     def _before_train(self):
-        # spawn a separate thread to run policy, can speed up 1.3x
-        self._populate_job_queue = queue.Queue(maxsize=1)
-
-        def populate_job_func():
-            self._populate_job_queue.get()
-            with self.trainer.sess.as_default():
-                for _ in range(self.update_frequency):
-                    self._populate_exp()
-        self._populate_job_th = LoopThread(populate_job_func, False)
-        self._populate_job_th.start()
-
         self._init_memory()

     def _trigger_epoch(self):
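Two separate changes land in this last hunk: _process_batch switches from np.array to np.asarray, and the thread setup moves out of _before_train into get_simulator_thread, so the caller decides when to start it (via StartProcOrThread in DQN.py above). The asarray change only matters when the input is already an ndarray, where it avoids a needless copy; on the lists built in _process_batch both calls behave identically. A short demonstration of the difference:

    import numpy as np

    a = np.arange(5)
    assert np.asarray(a) is a        # asarray: returns ndarray input unchanged
    assert np.array(a) is not a      # array: copies by default

    # On a list, as in _process_batch, both must build a new array anyway.
    batch = [np.zeros(3), np.zeros(3)]
    assert np.asarray(batch).shape == (2, 3)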