Project: seminar-breakout (Shashank Suhas)

Commit a1d1a4ae, authored May 16, 2016 by Yuxin Wu
Parent: 53b6112d

    DQN for atari

Showing 3 changed files with 406 additions and 1 deletion:

    examples/Atari2600/DQN.py          +297  -0
    examples/Atari2600/exp_replay.py   +108  -0
    tensorpack/models/fc.py              +1  -1
examples/Atari2600/DQN.py (new file, mode 0 → 100755)
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: DQN.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import numpy as np
import os, sys
import random
import argparse
from tqdm import tqdm
import multiprocessing

from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import ensure_proc_terminate
from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictWorker
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from exp_replay import AtariExpReplay

"""
Implement DQN in:
Human-level control through deep reinforcement learning
for atari games
"""

BATCH_SIZE = 32
IMAGE_SIZE = 84
NUM_ACTIONS = None
FRAME_HISTORY = 4
ACTION_REPEAT = 3
GAMMA = 0.99
BATCH_SIZE = 32
INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.0025
END_EXPLORATION = 0.1
INIT_MEMORY_SIZE = 50000
MEMORY_SIZE = 1e6

class Model(ModelDesc):
    def _get_input_vars(self):
        assert NUM_ACTIONS is not None
        return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'state'),
                InputVar(tf.int32, (None,), 'action'),
                InputVar(tf.float32, (None,), 'reward'),
                InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'next_state'),
                InputVar(tf.bool, (None,), 'isOver')]

    def _get_DQN_prediction(self, image, is_training):
        """ image: [0,255]"""
        image = image / 128.0 - 1
        with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
            l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=2)
            l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=2)
            l = Conv2D('conv2', l, out_channel=64, kernel_shape=4, stride=2)
            l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
        l = FullyConnected('fc0', l, 512)
        l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity,
                           summary_activation=False)
        return l

    def _build_graph(self, inputs, is_training):
        state, action, reward, next_state, isOver = inputs
        self.predict_value = self._get_DQN_prediction(state, is_training)
        action_onehot = symbf.one_hot(action, NUM_ACTIONS)
        pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)    #Nx1
        max_pred_reward = tf.reduce_mean(tf.reduce_max(self.predict_value, 1),
                                         name='predict_reward')
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, max_pred_reward)

        with tf.variable_scope('target'):
            targetQ_predict_value = tf.stop_gradient(
                self._get_DQN_prediction(next_state, False))    # NxA

        target = tf.select(isOver, reward,
                           reward + GAMMA * tf.reduce_max(targetQ_predict_value, 1))    # Nx1

        sqrcost = tf.square(target - pred_action_value)
        abscost = tf.abs(target - pred_action_value)    # robust error func
        cost = tf.select(abscost < 1, sqrcost, abscost)
        summary.add_param_summary([('.*/W', ['histogram'])])   # monitor histogram of all W
        self.cost = tf.reduce_mean(cost, name='cost')

    def update_target_param(self):
        vars = tf.trainable_variables()
        ops = []
        for v in vars:
            target_name = v.op.name
            if target_name.startswith('target'):
                new_name = target_name.replace('target/', '')
                logger.info("{} <- {}".format(target_name, new_name))
                ops.append(v.assign(
                    tf.get_default_graph().get_tensor_by_name(new_name + ':0')))
        return tf.group(*ops)

    def get_gradient_processor(self):
        return [MapGradient(lambda grad: \
                tf.clip_by_global_norm([grad], 5)[0][0]),
                SummaryGradient()]

def current_predictor(state):
    pred_var = tf.get_default_graph().get_tensor_by_name('fct/output:0')
    pred = pred_var.eval(feed_dict={'state:0': [state]})
    return pred[0]

class TargetNetworkUpdator(Callback):
    def __init__(self, M):
        self.M = M

    def _setup_graph(self):
        self.update_op = self.M.update_target_param()

    def _update(self):
        logger.info("Delayed Predictor updating...")
        self.update_op.run()

    def _before_train(self):
        self._update()

    def _trigger_epoch(self):
        self._update()

class ExpReplayController(Callback):
    def __init__(self, d):
        self.d = d

    def _before_train(self):
        self.d.init_memory()

    def _trigger_epoch(self):
        if self.d.exploration > END_EXPLORATION:
            self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
        logger.info("Exploration: {}".format(self.d.exploration))

def play_model(model_path, romfile):
    player = AtariPlayer(AtariDriver(romfile, viz=0.01),
                         action_repeat=ACTION_REPEAT)
    global NUM_ACTIONS
    NUM_ACTIONS = player.driver.get_num_actions()

    M = Model()
    cfg = PredictConfig(
            model=M,
            input_data_mapping=[0],
            session_init=SaverRestore(model_path),
            output_var_names=['fct/output:0'])
    predfunc = get_predict_func(cfg)
    tot_reward = 0
    while True:
        s = player.current_state()
        outputs = predfunc([[s]])
        action_value = outputs[0][0]
        act = action_value.argmax()
        print action_value, act
        if random.random() < 0.01:
            act = random.choice(range(player.driver.get_num_actions()))
        print(act)
        _, reward, isOver = player.action(act)
        tot_reward += reward
        if isOver:
            print("Total:", tot_reward)
            tot_reward = 0
            pbar.update()

def eval_model_multiprocess(model_path, romfile):
    M = Model()
    cfg = PredictConfig(
            model=M,
            input_data_mapping=[0],
            session_init=SaverRestore(model_path),
            output_var_names=['fct/output:0'])

    class Worker(ParallelPredictWorker):
        def __init__(self, idx, gpuid, config, outqueue):
            super(Worker, self).__init__(idx, gpuid, config)
            self.outq = outqueue

        def run(self):
            player = AtariPlayer(AtariDriver(romfile, viz=0),
                                 action_repeat=ACTION_REPEAT)
            global NUM_ACTIONS
            NUM_ACTIONS = player.driver.get_num_actions()
            self._init_runtime()

            tot_reward = 0
            while True:
                s = player.current_state()
                outputs = self.func([[s]])
                action_value = outputs[0][0]
                act = action_value.argmax()
                #print action_value, act
                if random.random() < 0.01:
                    act = random.choice(range(player.driver.get_num_actions()))
                #print(act)
                _, reward, isOver = player.action(act)
                tot_reward += reward
                if isOver:
                    self.outq.put(tot_reward)
                    tot_reward = 0

    NR_PROC = multiprocessing.cpu_count() // 2
    procs = []
    q = multiprocessing.Queue()
    for k in range(NR_PROC):
        procs.append(Worker(k, -1, cfg, q))
    ensure_proc_terminate(procs)
    for k in procs:
        k.start()
    stat = StatCounter()
    EVAL_EPISODE = 50
    with tqdm(total=EVAL_EPISODE) as pbar:
        while True:
            r = q.get()
            stat.feed(r)
            pbar.update()
            if stat.count() == EVAL_EPISODE:
                logger.info("Average Score: {}. Max Score: {}".format(
                    stat.average, stat.max))
                break

def get_config(romfile):
    basename = os.path.basename(__file__)
    logger.set_logger_dir(
        os.path.join('train_log', basename[:basename.rfind('.')]))
    M = Model()

    driver = AtariDriver(romfile)
    global NUM_ACTIONS
    NUM_ACTIONS = driver.get_num_actions()

    dataset_train = AtariExpReplay(
            predictor=current_predictor,
            player=AtariPlayer(driver, hist_len=FRAME_HISTORY,
                               action_repeat=ACTION_REPEAT),
            memory_size=MEMORY_SIZE,
            batch_size=BATCH_SIZE,
            populate_size=INIT_MEMORY_SIZE,
            exploration=INIT_EXPLORATION)

    lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
            TargetNetworkUpdator(M),
            ExpReplayController(dataset_train)
        ]),
        session_config=get_default_sess_config(0.5),
        model=M,
        step_per_epoch=10000,
        max_epoch=10000,
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
    parser.add_argument('--load', help='load model')
    parser.add_argument('--task', help='task to perform',
                        choices=['play', 'eval', 'train'], default='train')
    parser.add_argument('--rom', help='atari rom', required=True)
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.task != 'train':
        assert args.load is not None

    if args.task == 'play':
        play_model(args.load, args.rom)
        sys.exit()
    if args.task == 'eval':
        eval_model_multiprocess(args.load, args.rom)
        sys.exit()

    with tf.Graph().as_default():
        config = get_config(args.rom)
        if args.load:
            config.session_init = SaverRestore(args.load)
        SimpleTrainer(config).train()
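The loss in Model._build_graph follows the DeepMind recipe: the per-sample target is reward when isOver is true, and reward + GAMMA * max_a Q_target(next_state, a) otherwise, and the cost switches from squared to absolute error once the TD error exceeds 1, which keeps the gradient of large errors bounded. A minimal NumPy sketch of that piecewise cost (illustrative only, not part of the commit; the function name is mine):

import numpy as np

def robust_td_cost(target, pred_action_value):
    # Mirrors `cost = tf.select(abscost < 1, sqrcost, abscost)` from _build_graph.
    err = target - pred_action_value
    abscost = np.abs(err)
    sqrcost = np.square(err)
    return np.where(abscost < 1, sqrcost, abscost)

# Small errors are squared (shrunk), large errors grow only linearly:
print(robust_td_cost(np.array([1.0, 1.0]), np.array([0.5, -3.0])))   # [0.25, 4.0]

From the argparse block, the script is driven from the command line, e.g. ./DQN.py --rom <atari rom> --task train; --rom is required, and --task play / --task eval additionally require --load.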
examples/Atari2600/exp_replay.py (new file, mode 0 → 100755)
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: exp_replay.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from tensorpack.utils import *
from tqdm import tqdm
import random
import numpy as np
import cv2
from collections import deque, namedtuple

Experience = namedtuple('Experience',
        ['state', 'action', 'reward', 'next', 'isOver'])

def view_state(state):
    r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
    print r.shape
    cv2.imshow("state", r)
    cv2.waitKey()

class AtariExpReplay(DataFlow):
    """
    Implement experience replay
    """
    def __init__(self,
            predictor,
            player,
            memory_size=1e6,
            batch_size=32,
            populate_size=50000,
            exploration=1):
        """
        :param predictor: callabale. called with a state, return a distribution
        """
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.num_actions = self.player.driver.get_num_actions()
        logger.info("Number of Legal actions: {}".format(self.num_actions))
        self.mem = deque(maxlen=memory_size)
        self.rng = get_rng(self)

    def init_memory(self):
        logger.info("Populating replay memory...")
        with tqdm(total=self.populate_size) as pbar:
            while len(self.mem) < self.populate_size:
                self._populate_exp()
                pbar.update()

    def reset_state(self):
        raise RuntimeError("Don't run me in multiple processes")

    def _populate_exp(self):
        p = self.rng.rand()
        old_s = self.player.current_state()
        if p <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
        else:
            act = np.argmax(self.predictor(old_s))  # TODO race condition in session?
        _, reward, isOver = self.player.action(act)
        reward = np.clip(reward, -1, 2)
        s = self.player.current_state()
        #print act, reward
        #view_state(s)
        self.mem.append(Experience(old_s, act, reward, s, isOver))

    def get_data(self):
        while True:
            idxs = self.rng.randint(len(self.mem), size=self.batch_size)
            batch_exp = [self.mem[k] for k in idxs]
            yield self._process_batch(batch_exp)
            self._populate_exp()

    def _process_batch(self, batch_exp):
        state_shape = batch_exp[0].state.shape
        state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
        next_state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
        reward = np.zeros((self.batch_size,), dtype='float32')
        action = np.zeros((self.batch_size,), dtype='int32')
        isOver = np.zeros((self.batch_size,), dtype='bool')

        for idx, b in enumerate(batch_exp):
            state[idx] = b.state
            action[idx] = b.action
            next_state[idx] = b.next
            reward[idx] = b.reward
            isOver[idx] = b.isOver
        return [state, action, reward, next_state, isOver]

if __name__ == '__main__':
    predictor = lambda x: np.array([1, 1, 1, 1])
    predictor.initialized = False
    E = AtariExpReplay(predictor, predictor,
            AtariPlayer(AtariDriver('../../space_invaders.bin', viz=0.01)),
            populate_size=1000)
    E.init_memory()

    for k in E.get_data():
        pass
        #import IPython;
        #IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        #break
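AtariExpReplay interleaves acting and learning: every batch yielded by get_data also appends one fresh transition via _populate_exp, and sampling is uniform with replacement over the deque. A self-contained NumPy sketch of the same sampling-and-stacking pattern (assumptions: dummy states and a stand-alone sample_batch helper of my own; this is not the tensorpack API):

import numpy as np
from collections import deque, namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next', 'isOver'])

def sample_batch(mem, batch_size, rng=np.random):
    # Uniform sampling with replacement, as in AtariExpReplay.get_data.
    idxs = rng.randint(len(mem), size=batch_size)
    batch = [mem[k] for k in idxs]
    # Stack each Experience field into the batched arrays _process_batch produces.
    state      = np.stack([b.state for b in batch]).astype('float32')
    action     = np.array([b.action for b in batch], dtype='int32')
    reward     = np.array([b.reward for b in batch], dtype='float32')
    next_state = np.stack([b.next for b in batch]).astype('float32')
    isOver     = np.array([b.isOver for b in batch], dtype='bool')
    return [state, action, reward, next_state, isOver]

# Toy usage with dummy 84x84x4 frame stacks:
mem = deque(maxlen=1000)
dummy = np.zeros((84, 84, 4), dtype='uint8')
for _ in range(100):
    mem.append(Experience(dummy, 0, 0.0, dummy, False))
print([a.shape for a in sample_batch(mem, 32)])

The returned list matches the order of the InputVars declared in DQN.py: state, action, reward, next_state, isOver.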
tensorpack/models/fc.py (modified)
...
@@ -31,7 +31,7 @@ def FullyConnected(x, out_dim,
     if W_init is None:
         #W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
-        W_init = tf.uniform_unit_scaling_initializer()
+        W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
     if b_init is None:
         b_init = tf.constant_initializer()
...
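The one-line change bumps the scale of the default weight initializer used by FullyConnected. In TF 0.x, tf.uniform_unit_scaling_initializer draws weights uniformly with a bound proportional to factor / sqrt(fan_in), so a layer fed unit-variance inputs produces outputs with variance near factor**2; factor=1.43 (about sqrt(2)) was the value the old TF documentation suggested for ReLU activations, compensating for ReLU zeroing roughly half of its inputs. A rough NumPy sketch of the effect (my own simplified fan-in convention; illustrative, not the TF implementation):

import numpy as np

def uniform_unit_scaling(shape, factor=1.0, rng=np.random):
    # Draw from U(-limit, limit) with limit = factor * sqrt(3 / fan_in),
    # so each weight has variance factor**2 / fan_in.
    fan_in = shape[0]
    limit = factor * np.sqrt(3.0 / fan_in)
    return rng.uniform(-limit, limit, size=shape)

x = np.random.randn(10000, 512)                   # unit-variance inputs
for factor in (1.0, 1.43):
    W = uniform_unit_scaling((512, 256), factor)
    print(factor, np.var(x.dot(W)))               # ~1.0 vs ~2.0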