Shashank Suhas / seminar-breakout · Commits

Commit 0870401c, authored Feb 19, 2017 by Yuxin Wu
Parent: cab0c4f3

    speedup DQN.

Showing 5 changed files, with 156 additions and 105 deletions (+156 / -105):
examples/DeepQNetwork/DQN.py         +19   -10
examples/DeepQNetwork/atari.py        +1    -2
examples/DeepQNetwork/common.py       +1    -1
examples/DeepQNetwork/expreplay.py  +130   -91
tensorpack/utils/concurrency.py       +5    -1

examples/DeepQNetwork/DQN.py

@@ -20,7 +20,6 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.RL import *
 import common
@@ -34,7 +33,6 @@ FRAME_HISTORY = 4
 ACTION_REPEAT = 4
 CHANNEL = FRAME_HISTORY
-IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 GAMMA = 0.99
 INIT_EXPLORATION = 1
@@ -59,6 +57,7 @@ def get_player(viz=False, train=False):
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
+        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 30000)
@@ -73,10 +72,11 @@ class Model(ModelDesc):
         if NUM_ACTIONS is None:
             p = get_player()
             del p
-        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+        return [InputDesc(tf.uint8, (None,) + IMAGE_SIZE + (CHANNEL + 1,), 'comb_state'),
                 InputDesc(tf.int64, (None,), 'action'),
                 InputDesc(tf.float32, (None,), 'reward'),
-                InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
                 InputDesc(tf.bool, (None,), 'isOver')]

     def _get_DQN_prediction(self, image):
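
The single uint8 'comb_state' input replaces the two float32 image inputs above. A rough arithmetic sketch of the per-sample input size (not part of the commit; it assumes IMAGE_SIZE = (84, 84) and FRAME_HISTORY = 4, the defaults used by this example):

    # hypothetical back-of-the-envelope estimate, not code from the repository
    h, w, hist = 84, 84, 4
    old_bytes = 2 * h * w * hist * 4      # float32 'state' + 'next_state', 4 frames each
    new_bytes = h * w * (hist + 1) * 1    # one uint8 'comb_state' holding 5 frames
    print(old_bytes, new_bytes)           # 225792 vs 35280, roughly 6.4x less data per sample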
@@ -108,13 +108,20 @@ class Model(ModelDesc):
         return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
-        state, action, reward, next_state, isOver = inputs
         ctx = get_current_tower_context()
+        comb_state, action, reward, isOver = inputs
+        comb_state = tf.cast(comb_state, tf.float32)
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
         if not ctx.is_training:
             return
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')

         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)    # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(self.predict_value, 1), name='predict_reward')
-        add_moving_summary(max_pred_reward)
+        summary.add_moving_summary(max_pred_reward)

         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
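
For reference, a minimal NumPy sketch of what the two tf.slice calls above recover from one stored comb_state (not part of the commit; it again assumes IMAGE_SIZE = (84, 84) and FRAME_HISTORY = 4, and works on a single unbatched sample):

    import numpy as np

    FRAME_HISTORY = 4
    # one stored transition: frames t-3 .. t+1 stacked along the channel axis
    comb_state = np.random.randint(0, 256, size=(84, 84, FRAME_HISTORY + 1), dtype=np.uint8)

    state = comb_state[:, :, 0:FRAME_HISTORY]             # frames t-3 .. t   (offset 0, like the 'state' slice)
    next_state = comb_state[:, :, 1:FRAME_HISTORY + 1]    # frames t-2 .. t+1 (offset 1, like the 'next_state' slice)
    assert state.shape == next_state.shape == (84, 84, FRAME_HISTORY)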
@@ -137,7 +144,7 @@ class Model(ModelDesc):
             target - pred_action_value), name='cost')
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))    # monitor all W
-        add_moving_summary(self.cost)
+        summary.add_moving_summary(self.cost)

     def update_target_param(self):
         vars = tf.trainable_variables()
@@ -164,6 +171,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
+        state_shape=IMAGE_SIZE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -171,8 +179,9 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        reward_clip=(-1, 1),
-        history_len=FRAME_HISTORY)
+        history_len=FRAME_HISTORY,
+        reward_clip=(-1, 1)
+    )

     return TrainConfig(
         dataflow=expreplay,
@@ -215,7 +224,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
-            session_init=SaverRestore(args.load),
+            session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['Qvalue'])
         if args.task == 'play':

examples/DeepQNetwork/atari.py

@@ -106,7 +106,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen

@@ -119,7 +119,6 @@ class AtariPlayer(RLEnvironment):
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
-        ret = np.expand_dims(ret, axis=2)
         return ret.astype('uint8')  # to save some memory

     def get_action_space(self):
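
With the expand_dims call gone, current_state() now returns a plain (h, w) uint8 frame; for evaluation, get_player in DQN.py restores the channel axis through MapPlayerState (see the @@ -59,6 +57,7 @@ hunk above). A tiny NumPy sketch of that mapping, not taken from the repository and assuming 84x84 frames:

    import numpy as np

    frame = np.zeros((84, 84), dtype=np.uint8)    # shape returned by current_state() after this commit
    frame_3d = frame[:, :, np.newaxis]            # what MapPlayerState's lambda produces: (84, 84, 1)
    assert frame_3d.shape == (84, 84, 1)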

examples/DeepQNetwork/common.py

@@ -90,7 +90,7 @@ def eval_with_funcs(predict_funcs, nr_eval):
 def eval_model_multithread(cfg, nr_eval):
-    func = get_predict_func(cfg)
+    func = OfflinePredictor(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))

examples/DeepQNetwork/expreplay.py

(This diff is collapsed in the page view; its contents are not shown here.)

tensorpack/utils/concurrency.py

@@ -116,8 +116,12 @@ class ShareSessionThread(threading.Thread):
     @contextmanager
     def default_sess(self):
         if self._sess:
             with self._sess.as_default():
                 yield
+        else:
+            logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
+            yield

     def start(self):
         import tensorflow as tf
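
A hypothetical usage sketch of the pattern above (not from the repository; the subclass name and body are made up for illustration): a thread subclass wraps its run() body in default_sess(), which now degrades to a logged warning plus a bare yield instead of failing when no session was attached to the thread.

    from tensorpack.utils.concurrency import ShareSessionThread

    class EvalThread(ShareSessionThread):    # hypothetical subclass, for illustration only
        def run(self):
            with self.default_sess():
                # ops executed here run under the shared default session, if one was set;
                # otherwise default_sess() only logs a warning and the body still executes
                pass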