Project: Shashank Suhas / seminar-breakout

Commit 80a110e2, authored Aug 16, 2018 by Yuxin Wu
[DQN] Let state have channels.
Parent: f227f45f

Showing 5 changed files with 36 additions and 21 deletions:
examples/A3C-Gym/train-atari.py       +1  -2
examples/DeepQNetwork/DQN.py          +5  -4
examples/DeepQNetwork/DQNModel.py     +17 -8
examples/DeepQNetwork/README.md       +2  -3
examples/DeepQNetwork/expreplay.py    +11 -4
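Taken together, the diffs below change how a state is laid out: a frame now keeps an explicit channel axis (H, W, C), and frame history is concatenated along that channel axis instead of being stacked as a new trailing axis of a bare (H, W) frame. Below is a minimal numpy sketch of the two conventions, assuming the values used by the DQN example (IMAGE_SIZE = (84, 84), FRAME_HISTORY = 4, grayscale C = 1); with C > 1 (e.g. RGB observations) the new layout still works.

```python
import numpy as np

H, W, C = 84, 84, 1   # IMAGE_SIZE plus the now-explicit channel axis (assumed values)
HISTORY = 4           # FRAME_HISTORY

# Old convention: each frame is (H, W); history becomes a new last axis.
old_frames = [np.zeros((H, W), dtype='uint8') for _ in range(HISTORY)]
old_state = np.stack(old_frames, axis=2)
print(old_state.shape)        # (84, 84, 4)

# New convention: each frame is (H, W, C); history is concatenated
# along the existing channel axis.
new_frames = [np.zeros((H, W, C), dtype='uint8') for _ in range(HISTORY)]
new_state = np.concatenate(new_frames, axis=-1)
print(new_state.shape)        # (84, 84, 4) for C = 1, (84, 84, 12) for C = 3
```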
examples/A3C-Gym/train-atari.py

```diff
@@ -278,9 +278,8 @@ if __name__ == '__main__':
     args = parser.parse_args()
     ENV_NAME = args.env
-    logger.info("Environment Name: {}".format(ENV_NAME))
     NUM_ACTIONS = get_player().action_space.n
-    logger.info("Number of actions: {}".format(NUM_ACTIONS))
+    logger.info("Environment: {}, number of actions: {}".format(ENV_NAME, NUM_ACTIONS))
 
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
```
examples/DeepQNetwork/DQN.py

```diff
@@ -6,6 +6,7 @@
 import os
 import argparse
 import cv2
+import numpy as np
 import tensorflow as tf
@@ -40,7 +41,7 @@ def get_player(viz=False, train=False):
     env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                       live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
-    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
+    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis])
     if not train:
         # in training, history is taken care of in expreplay buffer
         env = FrameStack(env, FRAME_HISTORY)
@@ -49,10 +50,10 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
 
     def _get_DQN_prediction(self, image):
-        """ image: [0,255]"""
+        """ image: [N, H, W, C * history] in [0,255]"""
         image = image / 255.0
         with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
             l = (LinearWrap(image)
@@ -86,7 +87,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE,
+        state_shape=IMAGE_SIZE + (1,),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
```
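A note on the `[:, :, np.newaxis]` added to the `MapState` lambda: for a single-channel frame, `cv2.resize` returns a 2-D array, so the indexing restores the channel axis that the new `state_shape = IMAGE_SIZE + (1,)` expects. A small sketch of the shape effect (the 210x160 raw frame size is an illustrative assumption, not taken from this diff):

```python
import cv2
import numpy as np

IMAGE_SIZE = (84, 84)

# A dummy grayscale Atari frame; 210x160 is assumed for illustration.
frame = np.zeros((210, 160), dtype='uint8')

resized = cv2.resize(frame, IMAGE_SIZE)
print(resized.shape)          # (84, 84) -- no channel axis

with_channel = resized[:, :, np.newaxis]
print(with_channel.shape)     # (84, 84, 1) -- matches state_shape = IMAGE_SIZE + (1,)
```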
examples/DeepQNetwork/DQNModel.py

```diff
@@ -14,18 +14,24 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 class Model(ModelDesc):
     learning_rate = 1e-3
 
-    def __init__(self, image_shape, channel, method, num_actions, gamma):
-        self.image_shape = image_shape
-        self.channel = channel
+    def __init__(self, image_shape, channel, history, method, num_actions, gamma):
+        self.channel = channel
+        self._shape2d = image_shape
+        self._shape3d = image_shape + (channel, )
+        self._shape4d_for_prediction = (-1, ) + image_shape + (channel * history, )
+        self._channel = channel
+        self.history = history
         self.method = method
         self.num_actions = num_actions
         self.gamma = gamma
 
     def inputs(self):
-        # Use a combined state for efficiency.
-        # The first h channels are the current state, and the last h channels are the next state.
+        # When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
+        # Therefore we use a combined state for efficiency:
+        # The first h are the current state, and the last h are the next state.
         return [tf.placeholder(tf.uint8,
-                               (None,) + self.image_shape + (self.channel + 1,),
+                               (None,) + self._shape2d + (self._channel * (self.history + 1),),
                                'comb_state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'reward'),
@@ -35,20 +41,23 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image):
         pass
 
     # decorate the function
     @auto_reuse_variable_scope
     def get_DQN_prediction(self, image):
         return self._get_DQN_prediction(image)
 
     def build_graph(self, comb_state, action, reward, isOver):
         comb_state = tf.cast(comb_state, tf.float32)
-        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
+        comb_state = tf.reshape(comb_state, [-1] + list(self._shape3d) + [self.history + 1])
+        state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, -1, self.history])
+        state = tf.reshape(state, self._shape4d_for_prediction, name='state')
 
         self.predict_value = self.get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return
 
         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, self.channel], name='next_state')
+        next_state = tf.slice(comb_state, [0, 0, 0, 0, 1], [-1, -1, -1, -1, self.history], name='next_state')
+        next_state = tf.reshape(next_state, self._shape4d_for_prediction)
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
 
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
```
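The reshape-and-slice in `build_graph` is the core of the commit: `comb_state` arrives as [N, H, W, C * (history + 1)], is viewed as [N, H, W, C, history + 1], and the current and next states are the first and last `history` entries along that final axis, flattened back to [N, H, W, C * history] before going into `get_DQN_prediction`. A numpy sketch of the same bookkeeping, with shapes assumed from the DQN example (84x84, C = 1, history 4):

```python
import numpy as np

N, H, W, C, HIST = 2, 84, 84, 1, 4

# comb_state as fed to the graph: history+1 frames merged along the channel axis.
comb_state = np.zeros((N, H, W, C * (HIST + 1)), dtype='float32')

# Mirror of tf.reshape(comb_state, [-1] + list(self._shape3d) + [self.history + 1]).
comb = comb_state.reshape(N, H, W, C, HIST + 1)

# Current state: the first HIST entries of the last axis; next state: the last HIST.
state = comb[..., :HIST].reshape(N, H, W, C * HIST)     # -> _shape4d_for_prediction
next_state = comb[..., 1:].reshape(N, H, W, C * HIST)

print(state.shape, next_state.shape)   # (2, 84, 84, 4) (2, 84, 84, 4)
```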
examples/DeepQNetwork/README.md

```diff
@@ -20,10 +20,9 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 
-On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout.
-Batch-A3C implementation only took <2 hours.
+On one (Maxwell) TitanX, Double-DQN took ~18 hours of training to reach a score of 400 on breakout.
 
-Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
+Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on TitanX.
 
 ## How to use
```
examples/DeepQNetwork/expreplay.py

```diff
@@ -25,6 +25,9 @@ class ReplayMemory(object):
     def __init__(self, max_size, state_shape, history_len):
         self.max_size = int(max_size)
         self.state_shape = state_shape
+        self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
+        self._channel = state_shape[2] if len(state_shape) == 3 else 1
+        self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
         self.history_len = int(history_len)
 
         self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
@@ -62,7 +65,7 @@ class ReplayMemory(object):
     def sample(self, idx):
         """ return a tuple of (s,r,a,o),
-            where s is of shape STATE_SIZE + (hist_len+1,)"""
+            where s is of shape [H, W, channel * (hist_len+1)]"""
         idx = (self._curr_pos + idx) % self._curr_size
         k = self.history_len + 1
         if idx + k <= self._curr_size:
@@ -86,7 +89,9 @@ class ReplayMemory(object):
                 state = copy.deepcopy(state)
                 state[:k + 1].fill(0)
                 break
-        state = state.transpose(1, 2, 0)
+        # move the first dim to the last
+        state = state.transpose(*self._state_transpose)
+        state = state.reshape(self._shape3d)
         return (state, reward[-2], action[-2], isOver[-2])
 
     def _slice(self, arr, start, end):
@@ -130,11 +135,13 @@ class ExpReplay(DataFlow, Callback):
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
             player (RLEnvironment): the player.
+            state_shape (tuple): h, w, c
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
         """
+        assert len(state_shape) == 3, state_shape
         init_memory_size = int(init_memory_size)
 
         for k, v in locals().items():
@@ -195,10 +202,10 @@ class ExpReplay(DataFlow, Callback):
             # build a history state
             history = self.mem.recent_state()
             history.append(old_s)
-            history = np.stack(history, axis=2)
+            history = np.concatenate(history, axis=-1)
 
             # assume batched network
-            q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
+            q_values = self.predictor(np.expand_dims(history, 0))[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
         self._current_ob, reward, isOver, info = self.player.step(act)
         self._current_game_score.feed(reward)
```
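`ReplayMemory.sample` does the mirror-image bookkeeping: `history_len + 1` consecutive frames are gathered as [history+1, H, W, C], the leading time axis is moved to the end via `_state_transpose`, and the trailing axes are folded into one channel axis of size channel * (history_len + 1), matching the updated docstring. A numpy sketch with assumed shapes (84x84, C = 1, history 4):

```python
import numpy as np

H, W, C, HIST = 84, 84, 1, 4
state_shape = (H, W, C)

# As computed in ReplayMemory.__init__ after this commit:
state_transpose = list(range(1, len(state_shape) + 1)) + [0]      # [1, 2, 3, 0]
channel = state_shape[2] if len(state_shape) == 3 else 1
shape3d = (state_shape[0], state_shape[1], channel * (HIST + 1))  # (84, 84, 5)

# sample() gathers history+1 consecutive frames: [history+1, H, W, C].
stacked = np.zeros((HIST + 1,) + state_shape, dtype='uint8')

s = stacked.transpose(*state_transpose)   # (84, 84, 1, 5): time axis moved to the end
s = s.reshape(shape3d)                    # (84, 84, 5): [H, W, channel * (hist_len+1)]
print(s.shape)
```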