Commit 31cfcadf authored Apr 22, 2019 by Yuxin Wu
[DQN] split the environment runner from expreplay
parent 6d4a77c7
Showing 2 changed files with 106 additions and 57 deletions
examples/DeepQNetwork/DQN.py        +4 −1
examples/DeepQNetwork/expreplay.py  +102 −56
examples/DeepQNetwork/DQN.py  (view file @ 31cfcadf)

@@ -108,12 +108,15 @@ def get_config(model):
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
-        init_exploration=1.0,
         update_frequency=UPDATE_FREQ,
         history_len=FRAME_HISTORY,
         state_dtype=model.state_dtype.as_numpy_dtype
     )
+    # Set to other values if you need a different initial exploration
+    # (e.g., # if you're resuming a training half-way)
+    # expreplay.exploration = 1.0
     return TrainConfig(
         data=QueueInput(expreplay),
         model=model,
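
Since `init_exploration` is no longer a constructor argument, a run that should start from a different epsilon (for example one resumed half-way) now sets the attribute on the `ExpReplay` object directly, exactly as the new comment suggests. Below is a minimal sketch of that, plus one way the same attribute could be decayed during training with tensorpack's `ScheduledHyperParamSetter` over an `ObjAttrParam`; the schedule values are illustrative assumptions, not part of this commit:

# A sketch, assuming `expreplay` was built as in get_config() above.
# Override the starting epsilon directly on the object:
expreplay.exploration = 0.1   # hypothetical value, e.g. when resuming a training half-way

# The attribute can also be driven by a callback; the (epoch, epsilon) pairs
# below are made up for illustration:
from tensorpack.callbacks import ObjAttrParam, ScheduledHyperParamSetter

exploration_schedule = ScheduledHyperParamSetter(
    ObjAttrParam(expreplay, 'exploration'),
    [(0, 1.0), (10, 0.1), (400, 0.01)],
    interp='linear')
# ... then pass `exploration_schedule` among the callbacks given to TrainConfig.
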
examples/DeepQNetwork/expreplay.py  (view file @ 31cfcadf)

@@ -5,7 +5,7 @@
 import copy
 import numpy as np
 import threading
-from collections import deque, namedtuple
+from collections import namedtuple
 from six.moves import range
 from tensorpack.callbacks.base import Callback

@@ -30,7 +30,7 @@ class ReplayMemory(object):
         self.max_size = int(max_size)
         self.state_shape = state_shape
         assert len(state_shape) in [1, 2, 3], state_shape
-        self._output_shape = self.state_shape + (history_len + 1, )
+        # self._output_shape = self.state_shape + (history_len + 1, )
         self.history_len = int(history_len)
         self.dtype = dtype

@@ -45,7 +45,6 @@ class ReplayMemory(object):
         self._curr_size = 0
         self._curr_pos = 0
-        self._hist = deque(maxlen=history_len - 1)

     def append(self, exp):
         """

@@ -59,17 +58,6 @@ class ReplayMemory(object):
         else:
             self._assign(self._curr_pos, exp)
             self._curr_pos = (self._curr_pos + 1) % self.max_size
-        if exp.isOver:
-            self._hist.clear()
-        else:
-            self._hist.append(exp)
-
-    def recent_state(self):
-        """ return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
-        lst = list(self._hist)
-        states = [np.zeros(self.state_shape, dtype=self.dtype)] * (self._hist.maxlen - len(lst))
-        states.extend([k.state for k in lst])
-        return states

     def sample(self, idx):
         """ return a tuple of (s,r,a,o),

@@ -118,6 +106,92 @@ class ReplayMemory(object):
         self.isOver[pos] = exp.isOver


+class EnvRunner(object):
+    """
+    A class which is responsible for
+    stepping the environment with epsilon-greedy,
+    and fill the results to experience replay buffer.
+    """
+    def __init__(self, player, predictor, memory, history_len):
+        """
+        Args:
+            player (gym.Env)
+            predictor (callable): the model forward function which takes a
+                state and returns the prediction.
+            memory (ReplayMemory): the replay memory to store experience to.
+            history_len (int):
+        """
+        self.player = player
+        self.num_actions = player.action_space.n
+        self.predictor = predictor
+        self.memory = memory
+        self.state_shape = memory.state_shape
+        self.dtype = memory.dtype
+        self.history_len = history_len
+
+        self._current_episode = []
+        self._current_ob = player.reset()
+        self._current_game_score = StatCounter()  # store per-step reward
+        self._player_scores = StatCounter()  # store per-game total score
+        self.rng = get_rng(self)
+
+    def step(self, exploration):
+        """
+        Run the environment for one step.
+        If the episode ends, store the entire episode to the replay memory.
+        """
+        old_s = self._current_ob
+        if self.rng.rand() <= exploration:
+            act = self.rng.choice(range(self.num_actions))
+        else:
+            history = self.recent_state()
+            history.append(old_s)
+            history = np.stack(history, axis=-1)  # state_shape + (Hist,)
+            # assume batched network
+            history = np.expand_dims(history, axis=0)
+            q_values = self.predictor(history)[0][0]  # this is the bottleneck
+            act = np.argmax(q_values)
+
+        self._current_ob, reward, isOver, info = self.player.step(act)
+        self._current_game_score.feed(reward)
+        self._current_episode.append(Experience(old_s, act, reward, isOver))
+
+        if isOver:
+            flush_experience = True
+            if 'ale.lives' in info:  # if running Atari, do something special
+                if info['ale.lives'] != 0:
+                    # only record score and flush experience
+                    # when a whole game is over (not when an episode is over)
+                    flush_experience = False
+            self.player.reset()
+
+            if flush_experience:
+                self._player_scores.feed(self._current_game_score.sum)
+                self._current_game_score.reset()
+
+                # TODO lock here if having multiple runner
+                for exp in self._current_episode:
+                    self.memory.append(exp)
+                self._current_episode.clear()
+
+    def recent_state(self):
+        """
+        Get the recent state (with stacked history) of the environment.
+
+        Returns:
+            a list of ``hist_len-1`` elements, each of shape ``self.state_shape``
+        """
+        expected_len = self.history_len - 1
+        if len(self._current_episode) >= expected_len:
+            return [k.state for k in self._current_episode[-expected_len:]]
+        else:
+            states = [np.zeros(self.state_shape, dtype=self.dtype)] * (expected_len - len(self._current_episode))
+            states.extend([k.state for k in self._current_episode])
+            return states
+
+
 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
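
For orientation, here is a rough sketch of how the new EnvRunner can be exercised on its own, from within the context of examples/DeepQNetwork/expreplay.py where ReplayMemory and EnvRunner are defined. The toy environment and the random predictor are stand-ins invented for illustration; the real example wires in the Atari player and the trainer's Q-network predictor, as the later hunks show. Only the EnvRunner/ReplayMemory calls come from this commit:

import numpy as np

# Hypothetical stand-ins, for illustration only.
class _ToySpace(object):
    n = 4                                    # four discrete actions

class _ToyEnv(object):                       # minimal gym.Env-like interface
    action_space = _ToySpace()
    def reset(self):
        return np.zeros((84, 84), dtype='uint8')
    def step(self, act):
        ob = np.random.randint(0, 255, size=(84, 84)).astype('uint8')
        reward = 1.0
        isOver = np.random.rand() < 0.01     # end an "episode" occasionally
        return ob, reward, isOver, {}        # empty info: no 'ale.lives' key

def random_predictor(batched_state):
    # mimics the batched-network contract used above:
    # predictor(history)[0][0] must be a vector of Q-values
    return [np.random.rand(1, 4)]

memory = ReplayMemory(1000, (84, 84), history_len=4, dtype='uint8')
runner = EnvRunner(_ToyEnv(), random_predictor, memory, history_len=4)
for _ in range(1000):
    runner.step(exploration=0.5)             # epsilon-greedy; finished episodes are flushed to `memory`
print(len(memory), runner._player_scores.count)
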

@@ -137,7 +211,6 @@ class ExpReplay(DataFlow, Callback):
                  state_shape,
                  batch_size,
                  memory_size, init_memory_size,
-                 init_exploration,
                  update_frequency, history_len,
                  state_dtype='uint8'):
         """

@@ -146,10 +219,14 @@ class ExpReplay(DataFlow, Callback):
                 predict Q value from state.
             player (gym.Env): the player.
             state_shape (tuple):
+            history_len (int): length of history frames to concat. Zero-filled
+                initial frames.
             batch_size (int):
             memory_size (int):
             init_memory_size (int):
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
-            history_len (int): length of history frames to concat. Zero-filled
-                initial frames.
+            state_dtype (str):
         """
         assert len(state_shape) in [1, 2, 3], state_shape
         init_memory_size = int(init_memory_size)

@@ -157,24 +234,21 @@ class ExpReplay(DataFlow, Callback):
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
-        self.exploration = init_exploration
+        self.exploration = 1.0  # default initial exploration
         self.num_actions = player.action_space.n
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized

-        self.mem = ReplayMemory(memory_size, state_shape, history_len)
-        self._current_ob = self.player.reset()
-        self._player_scores = StatCounter()
-        self._current_game_score = StatCounter()
+        self.mem = ReplayMemory(memory_size, state_shape, self.history_len, dtype=state_dtype)

     def _init_memory(self):
         logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                self._populate_exp()
+                self.env_runner.step(self.exploration)
                 pbar.update()
         self._init_memory_flag.set()

@@ -183,42 +257,13 @@ class ExpReplay(DataFlow, Callback):
         from copy import deepcopy
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < 5:
-                self._populate_exp()
+                self.env_runner.step(self.exploration)
                 pbar.update()
             while len(self.mem) < self.init_memory_size:
                 self.mem.append(deepcopy(self.mem._hist[0]))
                 pbar.update()
         self._init_memory_flag.set()

-    def _populate_exp(self):
-        """ populate a transition by epsilon-greedy"""
-        old_s = self._current_ob
-        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
-            act = self.rng.choice(range(self.num_actions))
-        else:
-            # build a history state
-            history = self.mem.recent_state()
-            history.append(old_s)
-            history = np.stack(history, axis=-1)  # state_shape + (Hist,)
-            history = np.expand_dims(history, axis=0)  # assume batched network
-            q_values = self.predictor(history)[0][0]  # this is the bottleneck
-            act = np.argmax(q_values)
-        self._current_ob, reward, isOver, info = self.player.step(act)
-        self._current_game_score.feed(reward)
-
-        if isOver:
-            if 'ale.lives' in info:  # if running Atari, do something special for logging:
-                if info['ale.lives'] == 0:
-                    # only record score when a whole game is over (not when an episode is over)
-                    self._player_scores.feed(self._current_game_score.sum)
-                    self._current_game_score.reset()
-            else:
-                self._player_scores.feed(self._current_game_score.sum)
-                self._current_game_score.reset()
-            self.player.reset()
-        self.mem.append(Experience(old_s, act, reward, isOver))

     def _debug_sample(self, sample):
         import cv2

@@ -257,17 +302,18 @@ class ExpReplay(DataFlow, Callback):
             # execute 4 new actions into memory, after each batch update
             for _ in range(self.update_frequency):
-                self._populate_exp()
+                self.env_runner.step(self.exploration)

     # Callback methods:
     def _setup_graph(self):
-        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
+        predictor = self.trainer.get_predictor(*self.predictor_io_names)
+        self.env_runner = EnvRunner(self.player, predictor, self.mem, self.history_len)

     def _before_train(self):
         self._init_memory()

     def _trigger(self):
-        v = self._player_scores
+        v = self.env_runner._player_scores
         try:
             mean, max = v.average, v.max
             self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
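
The new EnvRunner.step() carries a "# TODO lock here if having multiple runner" note; this commit does not implement it, but the split makes the direction clear: several runners could feed one ReplayMemory as long as appends are serialized. A rough sketch of that idea, reusing the toy environment and random predictor from the earlier sketch; the `_LockedMemory` proxy and the thread setup below are illustrative assumptions, not part of the commit:

import threading

class _LockedMemory(object):
    """Illustrative proxy: serialize appends from several runners into one ReplayMemory."""
    def __init__(self, mem):
        self._mem = mem
        self._lock = threading.Lock()
        # EnvRunner only reads these two attributes from the memory it is given
        self.state_shape, self.dtype = mem.state_shape, mem.dtype
    def append(self, exp):
        with self._lock:
            self._mem.append(exp)

shared = ReplayMemory(100000, (84, 84), history_len=4, dtype='uint8')
locked = _LockedMemory(shared)
runners = [EnvRunner(_ToyEnv(), random_predictor, locked, history_len=4) for _ in range(4)]

# run each runner in its own thread, all flushing into the shared memory
threads = [threading.Thread(target=lambda r=r: [r.step(1.0) for _ in range(250)])
           for r in runners]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(shared))   # transitions gathered by all runners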