Shashank Suhas / seminar-breakout

Commit 6452b444, authored Feb 19, 2017 by Yuxin Wu

update DQN code

parent 0870401c

Showing 5 changed files with 53 additions and 42 deletions (+53 -42):
examples/DeepQNetwork/DQN.py        +31  -26
examples/DeepQNetwork/README.md      +4   -4
examples/DeepQNetwork/common.py      +3   -4
examples/DeepQNetwork/expreplay.py   +3   -5
scripts/ls-checkpoint.py            +12   -3
examples/DeepQNetwork/DQN.py

@@ -43,7 +43,7 @@ MEMORY_SIZE = 1e6
 # NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 # Suggest using tcmalloc to manage memory space better.
 INIT_MEMORY_SIZE = 5e4
-STEP_PER_EPOCH = 10000
+STEPS_PER_EPOCH = 10000
 EVAL_EPISODE = 50
 NUM_ACTIONS = None
@@ -54,8 +54,6 @@ METHOD = None
 def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                      image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
-    global NUM_ACTIONS
-    NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
         pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
     pl = HistoryFramePlayer(pl, FRAME_HISTORY)
@@ -69,9 +67,8 @@ common.get_player = get_player  # so that eval functions in common can use the p
 
 class Model(ModelDesc):
     def _get_inputs(self):
-        if NUM_ACTIONS is None:
-            p = get_player()
-            del p
+        # use a combined state, where the first channels are the current state,
+        # and the last 4 channels are the next state
        return [InputDesc(tf.uint8,
                          (None,) + IMAGE_SIZE + (CHANNEL + 1,),
                          'comb_state'),
@@ -102,28 +99,31 @@ class Model(ModelDesc):
         if METHOD != 'Dueling':
             Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
         else:
+            # Dueling DQN
             V = FullyConnected('fctV', l, 1, nl=tf.identity)
             As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
             Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
         return tf.identity(Q, name='Qvalue')
 
     def _build_graph(self, inputs):
-        ctx = get_current_tower_context()
         comb_state, action, reward, isOver = inputs
         comb_state = tf.cast(comb_state, tf.float32)
         state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
-        if not ctx.is_training:
+        if not get_current_tower_context().is_training:
             return
 
+        reward = tf.clip_by_value(reward, -1, 1)
         next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
 
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
         summary.add_moving_summary(max_pred_reward)
 
-        with tf.variable_scope('target'):
+        with tf.variable_scope('target'), \
+                collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
 
         if METHOD != 'Double':
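An aside on the Dueling branch above: `Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))` implements the aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); subtracting the mean advantage keeps V and A identifiable, since only their sum is constrained by the loss. A minimal NumPy sketch of the same arithmetic, with invented shapes (a batch of 2 states, 4 actions):

```python
import numpy as np

V = np.array([[1.0], [2.0]])                 # (N, 1) state values
A = np.array([[0.5, -0.5, 1.5, 0.0],
              [1.0,  1.0, 0.0, 2.0]])        # (N, num_actions) advantages

# same aggregation as tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
Q = A + (V - A.mean(axis=1, keepdims=True))
print(Q)
```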
@@ -146,17 +146,6 @@ class Model(ModelDesc):
             ('fc.*/W', ['histogram', 'rms']))   # monitor all W
         summary.add_moving_summary(self.cost)
 
-    def update_target_param(self):
-        vars = tf.trainable_variables()
-        ops = []
-        for v in vars:
-            target_name = v.op.name
-            if target_name.startswith('target'):
-                new_name = target_name.replace('target/', '')
-                logger.info("{} <- {}".format(target_name, new_name))
-                ops.append(v.assign(tf.get_default_graph().get_tensor_by_name(new_name + ':0')))
-        return tf.group(*ops, name='update_target_network')
-
     def _get_optimizer(self):
         lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
         opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
@@ -179,25 +168,36 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        history_len=FRAME_HISTORY,
-        reward_clip=(-1, 1))
+        history_len=FRAME_HISTORY)
 
+    def update_target_param():
+        vars = tf.trainable_variables()
+        ops = []
+        G = tf.get_default_graph()
+        for v in vars:
+            target_name = v.op.name
+            if target_name.startswith('target'):
+                new_name = target_name.replace('target/', '')
+                logger.info("{} <- {}".format(target_name, new_name))
+                ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
+        return tf.group(*ops, name='update_target_network')
+
     return TrainConfig(
         dataflow=expreplay,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
-            RunOp(lambda: M.update_target_param()),
+            RunOp(update_target_param),
             expreplay,
-            StartProcOrThread(expreplay.get_simulator_thread()),
-            PeriodicCallback(Evaluator(
-                EVAL_EPISODE, ['state'], ['Qvalue']), 3),
+            PeriodicTrigger(Evaluator(
+                EVAL_EPISODE, ['state'], ['Qvalue']),
+                every_k_epochs=5),
             # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
         ],
         model=M,
-        steps_per_epoch=STEP_PER_EPOCH,
+        steps_per_epoch=STEPS_PER_EPOCH,
         # run the simulator on a separate GPU if available
         predict_tower=[1] if get_nr_gpu() > 1 else [0],
     )
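For context, the `update_target_param()` closure added above performs a hard target-network sync: every variable under the `target/` scope is assigned the value of the like-named online variable, and the assignments are grouped into one op that `RunOp` executes periodically. A framework-free sketch of the same idea (parameter names and shapes invented for illustration, not taken from the commit):

```python
import numpy as np

params = {
    'fct/W': np.random.randn(512, 4),        # online network
    'target/fct/W': np.zeros((512, 4)),      # target network, same name minus prefix
}

def update_target(params):
    # hard copy: no Polyak averaging, one assignment per 'target/' variable
    for name in list(params):
        if name.startswith('target/'):
            src = name.replace('target/', '', 1)
            params[name] = params[src].copy()

update_target(params)
assert np.allclose(params['target/fct/W'], params['fct/W'])
```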
@@ -221,6 +221,11 @@ if __name__ == '__main__':
     ROM_FILE = args.rom
     METHOD = args.algo
+    # set num_actions
+    pl = AtariPlayer(ROM_FILE, viz=False)
+    NUM_ACTIONS = pl.get_action_space().num_actions()
+    del pl
 
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
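A note on the `comb_state` trick used throughout this file: because consecutive states share all but one frame, the current and next state are stored as a single (CHANNEL + 1)-channel stack and recovered by the two `tf.slice` calls in `_build_graph`. A small NumPy sketch of the same slicing, assuming the defaults of this example (84x84 frames, FRAME_HISTORY = 4):

```python
import numpy as np

# one combined sample: 84x84 frames, 4 history frames + 1 extra channel
comb_state = np.random.randint(0, 255, size=(84, 84, 5), dtype=np.uint8)

state      = comb_state[:, :, 0:4]   # frames t-3 .. t   (slice offset [.., 0], size 4)
next_state = comb_state[:, :, 1:5]   # frames t-2 .. t+1 (slice offset [.., 1], size 4)

# the two views overlap in 3 frames, so storing them jointly saves memory
assert np.array_equal(state[:, :, 1:], next_state[:, :, :-1])
```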
examples/DeepQNetwork/README.md

@@ -19,12 +19,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested
-DQN typically took 1.5 days of training to reach a score of 400 on breakout game (same as the paper).
+DQN typically took 1 day of training to reach a score of 400 on breakout game (same as the paper).
 My Batch-A3C implementation only took <2 hours.
 Both were trained on one GPU with an extra GPU for simulation.
 
-The x-axis is the number of iterations, not wall time.
-Double-DQN runs at 18 batches/s (1152 frames/s) on TitanX.
+Double-DQN is faster at the beginning but will converge to 12 batches/s (768 frames/s) due of exploration annealing.
 
@@ -37,9 +36,10 @@ To train:
 # use `--algo` to select other DQN algorithms. See `-h` for more options.
 ```
 
-To visualize the agent:
+To watch the agent play:
 ```
 ./DQN.py --rom breakout.bin --task play --load trained.model
 ```
 
+A pretrained model on breakout can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkN1Jrei1xWW0yR28).
+
 A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
examples/DeepQNetwork/common.py

@@ -11,7 +11,6 @@ from tqdm import tqdm
 from six.moves import queue
 
 from tensorpack import *
-from tensorpack.predict import get_predict_func
 from tensorpack.utils.concurrency import *
 from tensorpack.utils.stats import *
 
@@ -33,7 +32,7 @@ def play_one_episode(player, func, verbose=False):
 def play_model(cfg):
     player = get_player(viz=0.01)
-    predfunc = get_predict_func(cfg)
+    predfunc = OfflinePredictor(cfg)
     while True:
         score = play_one_episode(player, predfunc)
         print("Total:", score)
 
@@ -96,7 +95,7 @@ def eval_model_multithread(cfg, nr_eval):
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
 
-class Evaluator(Callback):
+class Evaluator(Triggerable):
     def __init__(self, nr_eval, input_names, output_names):
         self.eval_episode = nr_eval
         self.input_names = input_names
 
@@ -107,7 +106,7 @@ class Evaluator(Callback):
         self.pred_funcs = [self.trainer.get_predict_func(
             self.input_names, self.output_names)] * NR_PROC
 
-    def _trigger_epoch(self):
+    def _trigger(self):
         t = time.time()
         mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
         t = time.time() - t
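The switch from `_trigger_epoch` to `_trigger` goes together with DQN.py wrapping the Evaluator in `PeriodicTrigger(..., every_k_epochs=5)`: the callback now only says *what* to do, and the wrapper decides *when* it fires. A deliberately simplified, framework-free sketch of that pattern; class and method names are invented for illustration and are not tensorpack's actual API:

```python
class TriggerableSketch:
    """A callback that exposes a single trigger() action, with no schedule of its own."""
    def trigger(self):
        raise NotImplementedError

class EvaluatorSketch(TriggerableSketch):
    def trigger(self):
        print("run evaluation episodes")

class PeriodicTriggerSketch:
    """Fire the wrapped callback once every `every_k_epochs` epochs."""
    def __init__(self, cb, every_k_epochs):
        self.cb = cb
        self.every_k_epochs = every_k_epochs

    def on_epoch_end(self, epoch):
        if epoch % self.every_k_epochs == 0:
            self.cb.trigger()

wrapper = PeriodicTriggerSketch(EvaluatorSketch(), every_k_epochs=5)
for epoch in range(1, 16):
    wrapper.on_epoch_end(epoch)   # prints at epochs 5, 10, 15
```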
examples/DeepQNetwork/expreplay.py

@@ -124,8 +124,7 @@ class ExpReplay(DataFlow, Callback):
                  batch_size,
                  memory_size, init_memory_size,
                  exploration, end_exploration, exploration_epoch_anneal,
-                 update_frequency, history_len,
-                 reward_clip=None):
+                 update_frequency, history_len):
         """
         Args:
             predictor_io_names (tuple of list of str): input/output names to
 
@@ -191,8 +190,6 @@ class ExpReplay(DataFlow, Callback):
             q_values = self.predictor([[history]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
-        if self.reward_clip:
-            reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         self.mem.append(Experience(old_s, act, reward, isOver))
 
     def debug_sample(self, sample):
 
@@ -236,7 +233,8 @@ class ExpReplay(DataFlow, Callback):
     def _before_train(self):
         self._init_memory()
-        # TODO start thread here
+        self._simulator_th = self.get_simulator_thread()
+        self._simulator_th.start()
 
     def _trigger_epoch(self):
         if self.exploration > self.end_exploration:
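For orientation, the data structure behind `ExpReplay`: experiences `(state, action, reward, isOver)` are appended to a bounded memory and later sampled in batches for training. A deliberately simplified sketch (uniform sampling from a fixed-size deque); the real class additionally stacks `history_len` frames per sample, feeds the simulator from a background thread, and anneals exploration over epochs:

```python
import random
from collections import deque, namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])

class ReplayBufferSketch:
    def __init__(self, memory_size):
        self.mem = deque(maxlen=memory_size)   # oldest experiences fall off the left

    def append(self, exp):
        self.mem.append(exp)

    def sample(self, batch_size):
        return random.sample(list(self.mem), batch_size)

buf = ReplayBufferSketch(memory_size=1000)
for t in range(1000):
    buf.append(Experience(state=t, action=t % 4, reward=0.0, isOver=False))
batch = buf.sample(32)
```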
scripts/ls-checkpoint.py

@@ -4,11 +4,20 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
 import tensorflow as tf
+import numpy as np
+import six
 import sys
 import pprint
 
 from tensorpack.tfutils.varmanip import get_checkpoint_path
 
-path = get_checkpoint_path(sys.argv[1])
-reader = tf.train.NewCheckpointReader(path)
-pprint.pprint(reader.get_variable_to_shape_map())
+fpath = sys.argv[1]
+
+if fpath.endswith('.npy'):
+    params = np.load(fpath, encoding='latin1').item()
+    dic = {k: v.shape for k, v in six.iteritems(params)}
+else:
+    path = get_checkpoint_path(sys.argv[1])
+    reader = tf.train.NewCheckpointReader(path)
+    dic = reader.get_variable_to_shape_map()
+pprint.pprint(dic)
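The new `.npy` branch above lets the script inspect parameter dumps saved as a pickled dict, in addition to TF checkpoints. A small self-contained demonstration of that loading pattern; the file name and contents are invented, `.items()` stands in for `six.iteritems`, and recent NumPy versions may additionally need `allow_pickle=True` for such files:

```python
import pprint
import numpy as np

# write a fake parameter dump: a dict of name -> array, pickled inside a .npy file
np.save('fake-params.npy', {'conv0/W': np.zeros((5, 5, 4, 32)),
                            'fct/W': np.zeros((512, 4))})

# load it back and print {variable name: shape}, mirroring the script's .npy branch
params = np.load('fake-params.npy', encoding='latin1', allow_pickle=True).item()
pprint.pprint({k: v.shape for k, v in params.items()})
```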