Shashank Suhas / seminar-breakout / Commits

Commit 0870401c
authored Feb 19, 2017 by Yuxin Wu
parent cab0c4f3

speedup DQN.
Showing 5 changed files with 156 additions and 105 deletions.
- examples/DeepQNetwork/DQN.py (+19, -10)
- examples/DeepQNetwork/atari.py (+1, -2)
- examples/DeepQNetwork/common.py (+1, -1)
- examples/DeepQNetwork/expreplay.py (+130, -91)
- tensorpack/utils/concurrency.py (+5, -1)
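In outline, as read from the diffs below: the replay buffer is rewritten from a deque of Experience tuples into a preallocated uint8 ring buffer (ReplayMemory); the model now takes a single uint8 'comb_state' input, from which 'state' and 'next_state' are sliced in-graph, replacing two separate float32 inputs; and Atari frames are stored as bare (h, w) arrays, with the channel axis added back only where needed. All three changes cut memory footprint and per-batch copying.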
examples/DeepQNetwork/DQN.py

```diff
@@ -20,7 +20,6 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.RL import *
 import common
@@ -34,7 +33,6 @@ FRAME_HISTORY = 4
 ACTION_REPEAT = 4
 CHANNEL = FRAME_HISTORY
-IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 GAMMA = 0.99
 INIT_EXPLORATION = 1
@@ -59,6 +57,7 @@ def get_player(viz=False, train=False):
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
+        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 30000)
@@ -73,10 +72,11 @@ class Model(ModelDesc):
         if NUM_ACTIONS is None:
             p = get_player()
             del p
-        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+        return [InputDesc(tf.uint8,
+                          (None,) + IMAGE_SIZE + (CHANNEL + 1,),
+                          'comb_state'),
                 InputDesc(tf.int64, (None,), 'action'),
                 InputDesc(tf.float32, (None,), 'reward'),
-                InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
                 InputDesc(tf.bool, (None,), 'isOver')]

     def _get_DQN_prediction(self, image):
@@ -108,13 +108,20 @@ class Model(ModelDesc):
         return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
-        state, action, reward, next_state, isOver = inputs
+        ctx = get_current_tower_context()
+        comb_state, action, reward, isOver = inputs
+        comb_state = tf.cast(comb_state, tf.float32)
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
+        if not ctx.is_training:
+            return
+
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)

         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(self.predict_value, 1),
                                          name='predict_reward')
-        add_moving_summary(max_pred_reward)
+        summary.add_moving_summary(max_pred_reward)

         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
@@ -137,7 +144,7 @@ class Model(ModelDesc):
                                           target - pred_action_value), name='cost')
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
-        add_moving_summary(self.cost)
+        summary.add_moving_summary(self.cost)

     def update_target_param(self):
         vars = tf.trainable_variables()
@@ -164,6 +171,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
+        state_shape=IMAGE_SIZE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -171,8 +179,9 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        reward_clip=(-1, 1),
-        history_len=FRAME_HISTORY)
+        history_len=FRAME_HISTORY,
+        reward_clip=(-1, 1)
+    )
     return TrainConfig(
         dataflow=expreplay,
@@ -215,7 +224,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
-            session_init=SaverRestore(args.load),
+            session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['Qvalue'])
         if args.task == 'play':
```
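The 'comb_state' input is the core of this diff: one uint8 tensor of FRAME_HISTORY + 1 consecutive frames replaces two float32 tensors ('state' and 'next_state') that duplicated each other's frames. A minimal numpy sketch of what the two tf.slice calls recover (not part of the commit; shapes assume IMAGE_SIZE = (84, 84) and FRAME_HISTORY = 4):

```python
import numpy as np

FRAME_HISTORY = 4
h, w = 84, 84  # assumed IMAGE_SIZE

# One replay entry: FRAME_HISTORY + 1 consecutive uint8 frames.
comb_state = np.random.randint(0, 256, size=(h, w, FRAME_HISTORY + 1), dtype=np.uint8)

# The two overlapping windows that tf.slice extracts in _build_graph:
state = comb_state[:, :, :FRAME_HISTORY]   # channels 0..3 -> 'state'
next_state = comb_state[:, :, 1:]          # channels 1..4 -> 'next_state'
assert state.shape == next_state.shape == (h, w, FRAME_HISTORY)

# Storage per transition: 5 uint8 planes instead of 2 x 4 float32 planes.
old_bytes = 2 * FRAME_HISTORY * h * w * 4  # separate float32 state + next_state
new_bytes = (FRAME_HISTORY + 1) * h * w    # single uint8 comb_state
print(old_bytes / float(new_bytes))        # ~6.4x less data per sampled transition
```

The cast to float32 then happens once, in-graph, instead of inside the replay buffer.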
examples/DeepQNetwork/atari.py

```diff
@@ -106,7 +106,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -119,7 +119,6 @@ class AtariPlayer(RLEnvironment):
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
-        ret = np.expand_dims(ret, axis=2)
         return ret.astype('uint8')    # to save some memory

     def get_action_space(self):
```
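With expand_dims gone, the environment emits bare (h, w) frames; the channel axis reappears either by stacking a frame history (np.stack in the new _populate_exp below) or, at evaluation time, via the MapPlayerState added to get_player in DQN.py. A small sketch, assuming 84x84 frames:

```python
import numpy as np

frames = [np.zeros((84, 84), dtype=np.uint8) for _ in range(4)]

# New code path: (h, w) frames gain the channel axis when stacked.
stacked = np.stack(frames, axis=2)                # (84, 84, 4)

# Old code path: each frame carried a redundant (h, w, 1) axis and
# histories were built by concatenation instead.
frames3 = [f[:, :, np.newaxis] for f in frames]
concatenated = np.concatenate(frames3, axis=2)    # same (84, 84, 4)
assert stacked.shape == concatenated.shape
```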
examples/DeepQNetwork/common.py

```diff
@@ -90,7 +90,7 @@ def eval_with_funcs(predict_funcs, nr_eval):

 def eval_model_multithread(cfg, nr_eval):
-    func = get_predict_func(cfg)
+    func = OfflinePredictor(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
```
examples/DeepQNetwork/expreplay.py

```diff
@@ -4,10 +4,11 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import numpy as np
+import copy
 from collections import deque, namedtuple
 import threading
 import six
-from six.moves import queue
+from six.moves import queue, range

 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger, get_tqdm, get_rng
```
```diff
@@ -20,6 +21,89 @@ Experience = namedtuple('Experience',
                         ['state', 'action', 'reward', 'isOver'])


+class ReplayMemory(object):
+    def __init__(self, max_size, state_shape, history_len):
+        self.max_size = int(max_size)
+        self.state_shape = state_shape
+        self.history_len = int(history_len)
+
+        self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
+        self.action = np.zeros((self.max_size,), dtype='int32')
+        self.reward = np.zeros((self.max_size,), dtype='float32')
+        self.isOver = np.zeros((self.max_size,), dtype='bool')
+
+        self._curr_size = 0
+        self._curr_pos = 0
+        self._hist = deque(maxlen=history_len - 1)
+
+    def append(self, exp):
+        """
+        Args:
+            exp (Experience):
+        """
+        if self._curr_size < self.max_size:
+            self._assign(self._curr_pos, exp)
+            self._curr_pos = (self._curr_pos + 1) % self.max_size
+            self._curr_size += 1
+        else:
+            self._assign(self._curr_pos, exp)
+            self._curr_pos = (self._curr_pos + 1) % self.max_size
+        if exp.isOver:
+            self._hist.clear()
+        else:
+            self._hist.append(exp)
+
+    def recent_state(self):
+        """ return a list of (hist_len-1,) + STATE_SIZE """
+        lst = list(self._hist)
+        states = [np.zeros(self.state_shape, dtype='uint8')] * (self._hist.maxlen - len(lst))
+        states.extend([k.state for k in lst])
+        return states
+
+    def sample(self, idx):
+        """ return a tuple of (s,r,a,o),
+            where s is of shape STATE_SIZE + (hist_len+1,)"""
+        idx = (self._curr_pos + idx) % self._curr_size
+        k = self.history_len + 1
+        if idx + k <= self._curr_size:
+            state = self.state[idx: idx + k]
+            reward = self.reward[idx: idx + k]
+            action = self.action[idx: idx + k]
+            isOver = self.isOver[idx: idx + k]
+        else:
+            end = idx + k - self._curr_size
+            state = self._slice(self.state, idx, end)
+            reward = self._slice(self.reward, idx, end)
+            action = self._slice(self.action, idx, end)
+            isOver = self._slice(self.isOver, idx, end)
+        ret = self._pad_sample(state, reward, action, isOver)
+        return ret
+
+    # the next_state is a different episode if current_state.isOver==True
+    def _pad_sample(self, state, reward, action, isOver):
+        for k in range(self.history_len - 2, -1, -1):
+            if isOver[k]:
+                state = copy.deepcopy(state)
+                state[:k + 1].fill(0)
+                break
+        state = state.transpose(1, 2, 0)
+        return (state, reward[-2], action[-2], isOver[-2])
+
+    def _slice(self, arr, start, end):
+        s1 = arr[start:]
+        s2 = arr[:end]
+        return np.concatenate((s1, s2), axis=0)
+
+    def __len__(self):
+        return self._curr_size
+
+    def _assign(self, pos, exp):
+        self.state[pos] = exp.state
+        self.reward[pos] = exp.reward
+        self.action[pos] = exp.action
+        self.isOver[pos] = exp.isOver
+
+
 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
```
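A usage sketch of the new ReplayMemory (hypothetical sizes; assumes the ReplayMemory and Experience definitions above). Transitions live in flat preallocated arrays instead of a deque of tuples, and sample(i) reads a window of history_len + 1 consecutive slots, wrapping around the buffer edge via _slice:

```python
import numpy as np

mem = ReplayMemory(max_size=1000, state_shape=(84, 84), history_len=4)
for _ in range(100):
    mem.append(Experience(np.zeros((84, 84), 'uint8'), 0, 0.0, False))

state, reward, action, isOver = mem.sample(10)
print(state.shape)   # (84, 84, 5): one comb_state, i.e. FRAME_HISTORY + 1 frames
```

_pad_sample also zero-fills frames belonging to an earlier episode, the job the old per-sample zero_fill loop used to do.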
```diff
@@ -36,16 +120,12 @@ class ExpReplay(DataFlow, Callback):
     def __init__(self,
                  predictor_io_names,
-                 player,
-                 batch_size=32,
-                 memory_size=1e6,
-                 init_memory_size=50000,
-                 exploration=1,
-                 end_exploration=0.1,
-                 exploration_epoch_anneal=0.002,
-                 reward_clip=None,
-                 update_frequency=1,
-                 history_len=1):
+                 player, state_shape,
+                 batch_size, memory_size, init_memory_size,
+                 exploration, end_exploration,
+                 exploration_epoch_anneal,
+                 update_frequency, history_len,
+                 reward_clip=None):
         """
         Args:
             predictor_io_names (tuple of list of str): input/output names to
```
```diff
@@ -63,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
             setattr(self, k, v)
         self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
-        self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
@@ -71,8 +151,10 @@ class ExpReplay(DataFlow, Callback):
         # a queue to receive notifications to populate memory
         self._populate_job_queue = queue.Queue(maxsize=5)

+        self.mem = ReplayMemory(memory_size, state_shape, history_len)
+
     def get_simulator_thread(self):
-        # spawn a separate thread to run policy, can speed up 1.3x
+        # spawn a separate thread to run policy
         def populate_job_func():
             self._populate_job_queue.get()
             for _ in range(self.update_frequency):
```
```diff
@@ -84,10 +166,6 @@ class ExpReplay(DataFlow, Callback):
     def _init_memory(self):
         logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

-        # fill some for the history
-        for k in range(self.history_len):
-            self._populate_exp()
-
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
                 self._populate_exp()
```
```diff
@@ -96,108 +174,69 @@ class ExpReplay(DataFlow, Callback):
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
-        # if len(self.mem):
-        #     from copy import deepcopy  # quickly fill the memory for debug
-        #     self.mem.append(deepcopy(self.mem[0]))
-        #     return
+        # if len(self.mem) > 4 and not self._init_memory_flag.is_set():
+        #     from copy import deepcopy  # quickly fill the memory for debug
+        #     self.mem.append(deepcopy(self.mem._hist[0]))
+        #     return
         old_s = self.player.current_state()
-        if self.rng.rand() <= self.exploration:
+        if self.rng.rand() <= self.exploration or len(self.mem) < 5:
             act = self.rng.choice(range(self.num_actions))
         else:
-            # build a history state
-            # assume a state can be representated by one tensor
-            ss = [old_s]
-
-            isOver = False
-            for k in range(1, self.history_len):
-                hist_exp = self.mem[-k]
-                if hist_exp.isOver:
-                    isOver = True
-                if isOver:
-                    # fill the beginning of an episode with zeros
-                    ss.append(np.zeros_like(ss[0]))
-                else:
-                    ss.append(hist_exp.state)
-            ss.reverse()
-            ss = np.concatenate(ss, axis=2)
+            history = self.mem.recent_state()
+            history.append(old_s)
+            history = np.stack(history, axis=2)
             # assume batched network
-            q_values = self.predictor([[ss]])[0][0]
+            q_values = self.predictor([[history]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         self.mem.append(Experience(old_s, act, reward, isOver))

+    def debug_sample(self, sample):
+        import cv2
+
+        def view_state(comb_state):
+            state = comb_state[:, :, :-1]
+            next_state = comb_state[:, :, 1:]
+            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
+            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
+            r = np.concatenate([r, r2], axis=0)
+            cv2.imshow("state", r)
+            cv2.waitKey()
+        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
+        if sample[1] or sample[3]:
+            view_state(sample[0])
+
     def get_data(self):
         # wait for memory to be initialized
         self._init_memory_flag.wait()

         while True:
-            batch_exp = [self._sample_one() for _ in range(self.batch_size)]
-
-            # import cv2  # for debug
-            # def view_state(state, next_state):
-            #     """ for debugging state representation"""
-            #     r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
-            #     r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
-            #     r = np.concatenate([r, r2], axis=0)
-            #     print r.shape
-            #     cv2.imshow("state", r)
-            #     cv2.waitKey()
-            # exp = batch_exp[0]
-            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
-            # if exp[2] or exp[4]:
-            #     view_state(exp[0], exp[1])
+            idx = self.rng.randint(
+                self._populate_job_queue.maxsize * self.update_frequency,
+                len(self.mem) - self.history_len - 1,
+                size=self.batch_size)
+            batch_exp = [self.mem.sample(i) for i in idx]
             yield self._process_batch(batch_exp)
             self._populate_job_queue.put(1)

-    # new state is considered useless if isOver==True
-    def _sample_one(self):
-        """ return the transition tuple for
-            [idx, idx+history_len) -> [idx+1, idx+1+history_len)
-            it's the transition from state idx+history_len-1 to state idx+history_len
-        """
-        # look for a state to start with
-        # when x.isOver==True, (x+1).state is of a different episode
-        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
-        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]
-
-        def concat(idx):
-            v = [x.state for x in samples[idx: idx + self.history_len]]
-            return np.concatenate(v, axis=2)
-        state = concat(0)
-        next_state = concat(1)
-        start_mem = samples[-2]
-        reward, action, isOver = start_mem.reward, start_mem.action, start_mem.isOver
-
-        start_idx = self.history_len - 1
-
-        # zero-fill state before starting
-        zero_fill = False
-        for k in range(1, self.history_len):
-            if samples[start_idx - k].isOver:
-                zero_fill = True
-            if zero_fill:
-                state[:, :, -k - 1] = 0
-                if k + 2 <= self.history_len:
-                    next_state[:, :, -k - 2] = 0
-        return (state, next_state, reward, action, isOver)
-
     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp])
-        next_state = np.asarray([e[1] for e in batch_exp])
-        reward = np.asarray([e[2] for e in batch_exp])
-        action = np.asarray([e[3] for e in batch_exp], dtype='int8')
-        isOver = np.asarray([e[4] for e in batch_exp], dtype='bool')
-        return [state, action, reward, next_state, isOver]
+        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
+        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
+        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
+        return [state, action, reward, isOver]

     def _setup_graph(self):
         self.predictor = self.trainer.get_predict_func(*self.predictor_io_names)

     def _before_train(self):
         self._init_memory()
         # TODO start thread here

     def _trigger_epoch(self):
         if self.exploration > self.end_exploration:
```
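One subtlety in the new get_data: ReplayMemory.sample(i) reads slot (_curr_pos + i) % _curr_size, so offsets count forward from the write position, and the smallest offsets are the slots the concurrent simulator thread will overwrite next. The lower bound of the randint therefore skips as many slots as pending writes could touch while a batch is assembled (queue backlog times update_frequency), and the upper bound leaves room for a full history_len + 1 window. This batched indexing also replaces the old per-sample Python loop in _sample_one, which is the other half of the speedup. Illustrative arithmetic, not from the commit:

```python
maxsize = 5           # self._populate_job_queue.maxsize
update_frequency = 4
history_len = 4
mem_len = 1000000     # len(self.mem) once the buffer is full

low = maxsize * update_frequency    # 20: slots that pending writes may touch
high = mem_len - history_len - 1    # keep a history_len + 1 window in range
print(low, high)                    # bounds passed to rng.randint(low, high, size=batch_size)
```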
tensorpack/utils/concurrency.py

```diff
@@ -116,7 +116,11 @@ class ShareSessionThread(threading.Thread):
     @contextmanager
     def default_sess(self):
-        with self._sess.as_default():
+        if self._sess:
+            with self._sess.as_default():
+                yield
+        else:
+            logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
             yield

     def start(self):
```
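The guard lets default_sess degrade to a warning plus a bare yield when the thread starts before any TensorFlow session exists (self._sess is None), instead of raising AttributeError. A standalone rendition of the same pattern, for illustration only:

```python
from contextlib import contextmanager

@contextmanager
def default_sess(sess):
    # Mirrors the guarded method above: enter the session's context if one
    # is available, otherwise warn and proceed without a default session.
    if sess:
        with sess.as_default():
            yield
    else:
        print("no session attached; proceeding without a default session")
        yield
```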