Shashank Suhas / seminar-breakout / Commits

Commit 4414d3ba, authored May 05, 2017 by Yuxin Wu
refactor RL a bit
Parent: 249052e0
Showing 4 changed files, with 121 additions and 99 deletions.
examples/A3C-Gym/train-atari.py      +5   -6
examples/DeepQNetwork/DQN.py         +11  -82
examples/DeepQNetwork/DQNModel.py    +96  -0
examples/DeepQNetwork/common.py      +9   -11
examples/A3C-Gym/train-atari.py
@@ -71,9 +71,6 @@ def get_player(viz=False, train=False, dumpdir=None):
     return pl
 
 
-common.get_player = get_player
-
-
 class MySimulatorWorker(SimulatorProcess):
     def _build_player(self):
         return get_player(train=True)
@@ -230,7 +227,9 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             master,
             StartProcOrThread(master),
-            PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['policy']), every_k_epochs=2),
+            PeriodicTrigger(Evaluator(
+                EVAL_EPISODE, ['state'], ['policy'], get_player),
+                every_k_epochs=2),
         ],
         session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
@@ -280,9 +279,9 @@ if __name__ == '__main__':
             input_names=['state'],
             output_names=['policy'])
         if args.task == 'play':
-            play_model(cfg)
+            play_model(cfg, get_player(viz=0.01))
         elif args.task == 'eval':
-            eval_model_multithread(cfg, args.episode)
+            eval_model_multithread(cfg, args.episode, get_player)
         elif args.task == 'gen_submit':
             run_submission(cfg, args.output, args.episode)
         else:
examples/DeepQNetwork/DQN.py
@@ -18,10 +18,10 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
-from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.RL import *
 import tensorflow as tf
 
+from DQNModel import Model as DQNModel
 import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
@@ -61,20 +61,7 @@ def get_player(viz=False, train=False):
     return pl
 
 
-common.get_player = get_player  # so that eval functions in common can use the player
-
-
-class Model(ModelDesc):
-    def _get_inputs(self):
-        # use a combined state, where the first channels are the current state,
-        # and the last 4 channels are the next state
-        return [InputDesc(tf.uint8,
-                          (None,) + IMAGE_SIZE + (CHANNEL + 1,), 'comb_state'),
-                InputDesc(tf.int64, (None,), 'action'),
-                InputDesc(tf.float32, (None,), 'reward'),
-                InputDesc(tf.bool, (None,), 'isOver')]
-
+class Model(DQNModel):
     def _get_DQN_prediction(self, image):
         """ image: [0,255]"""
         image = image / 255.0
@@ -95,67 +82,20 @@ class Model(ModelDesc):
             # .Conv2D('conv2', out_channel=64, kernel_shape=3)
             .FullyConnected('fc0', 512, nl=LeakyReLU)())
-        if METHOD != 'Dueling':
-            Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
+        if self.method != 'Dueling':
+            Q = FullyConnected('fct', l, self.num_actions, nl=tf.identity)
         else:
             # Dueling DQN
             V = FullyConnected('fctV', l, 1, nl=tf.identity)
-            As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
+            As = FullyConnected('fctA', l, self.num_actions, nl=tf.identity)
             Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
         return tf.identity(Q, name='Qvalue')
 
-    def _build_graph(self, inputs):
-        comb_state, action, reward, isOver = inputs
-        comb_state = tf.cast(comb_state, tf.float32)
-        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
-        self.predict_value = self._get_DQN_prediction(state)
-        if not get_current_tower_context().is_training:
-            return
-
-        reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
-        action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
-
-        pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
-        max_pred_reward = tf.reduce_mean(tf.reduce_max(
-            self.predict_value, 1), name='predict_reward')
-        summary.add_moving_summary(max_pred_reward)
-
-        with tf.variable_scope('target'), \
-                collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
-            targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
-
-        if METHOD != 'Double':
-            # DQN
-            best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
-        else:
-            # Double-DQN
-            sc = tf.get_variable_scope()
-            with tf.variable_scope(sc, reuse=True):
-                next_predict_value = self._get_DQN_prediction(next_state)
-            self.greedy_choice = tf.argmax(next_predict_value, 1)   # N,
-            predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
-            best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
-
-        target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
-
-        self.cost = tf.reduce_mean(symbf.huber_loss(
-            target - pred_action_value), name='cost')
-        summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
-                                  ('fc.*/W', ['histogram', 'rms']))   # monitor all W
-        summary.add_moving_summary(self.cost)
-
-    def _get_optimizer(self):
-        lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
-        opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
-        return optimizer.apply_grad_processors(
-            opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
-
 
 def get_config():
     logger.auto_set_dir()
-    M = Model()
+    M = Model(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
@@ -170,28 +110,17 @@ def get_config():
         history_len=FRAME_HISTORY)
 
-    def update_target_param():
-        vars = tf.global_variables()
-        ops = []
-        G = tf.get_default_graph()
-        for v in vars:
-            target_name = v.op.name
-            if target_name.startswith('target'):
-                new_name = target_name.replace('target/', '')
-                logger.info("{} <- {}".format(target_name, new_name))
-                ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
-        return tf.group(*ops, name='update_target_network')
-
     return TrainConfig(
         dataflow=expreplay,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
-            RunOp(update_target_param),
+            RunOp(DQNModel.update_target_param),
            expreplay,
-            PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), every_k_epochs=5),
+            PeriodicTrigger(Evaluator(
+                EVAL_EPISODE, ['state'], ['Qvalue'], get_player), every_k_epochs=5),
            # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
         ],
@@ -232,9 +161,9 @@ if __name__ == '__main__':
             input_names=['state'],
             output_names=['Qvalue'])
         if args.task == 'play':
-            play_model(cfg)
+            play_model(cfg, get_player(viz=0.01))
         elif args.task == 'eval':
-            eval_model_multithread(cfg, EVAL_EPISODE)
+            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
     else:
         config = get_config()
         if args.load:
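The net effect on DQN.py is that the example now supplies only the Atari-specific Q-network head and passes its hyperparameters into the model explicitly; inputs, loss, optimizer and target-network sync are inherited from the DQNModel base class introduced in the next file. Below is a minimal sketch of that subclassing pattern; the tiny fully-connected head and the constant values are illustrative stand-ins, not the example's real architecture.

# Sketch only: the subclassing pattern after this refactor. The FC head and the
# constants below are placeholders for the real ones in DQN.py.
import tensorflow as tf
from tensorpack import FullyConnected
from DQNModel import Model as DQNModel

IMAGE_SIZE = (84, 84)   # assumed values for illustration
CHANNEL = 4
METHOD = 'Double'       # 'DQN', 'Double' or 'Dueling'
NUM_ACTIONS = 6
GAMMA = 0.99


class Model(DQNModel):
    def _get_DQN_prediction(self, image):
        """image arrives as float32 in [0, 255]; scale it as the example does."""
        image = image / 255.0
        l = FullyConnected('fc0', image, 512, nl=tf.nn.relu)            # placeholder head
        Q = FullyConnected('fct', l, self.num_actions, nl=tf.identity)  # per-action Q-values
        return tf.identity(Q, name='Qvalue')


# hyperparameters are passed to the constructor instead of being read from globals:
M = Model(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)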
examples/DeepQNetwork/DQNModel.py   (new file, mode 100644, +96)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: DQNModel.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import abc
import tensorflow as tf

from tensorpack import ModelDesc, InputDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
    collection, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils import symbolic_functions as symbf


class Model(ModelDesc):
    def __init__(self, image_shape, channel, method, num_actions, gamma):
        self.image_shape = image_shape
        self.channel = channel
        self.method = method
        self.num_actions = num_actions
        self.gamma = gamma

    def _get_inputs(self):
        # use a combined state, where the first channels are the current state,
        # and the last 4 channels are the next state
        return [InputDesc(tf.uint8,
                          (None,) + self.image_shape + (self.channel + 1,),
                          'comb_state'),
                InputDesc(tf.int64, (None,), 'action'),
                InputDesc(tf.float32, (None,), 'reward'),
                InputDesc(tf.bool, (None,), 'isOver')]

    @abc.abstractmethod
    def _get_DQN_prediction(self, image):
        pass

    def _build_graph(self, inputs):
        comb_state, action, reward, isOver = inputs
        comb_state = tf.cast(comb_state, tf.float32)
        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
        self.predict_value = self._get_DQN_prediction(state)
        if not get_current_tower_context().is_training:
            return

        reward = tf.clip_by_value(reward, -1, 1)
        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
        action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)

        pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
        max_pred_reward = tf.reduce_mean(tf.reduce_max(
            self.predict_value, 1), name='predict_reward')
        summary.add_moving_summary(max_pred_reward)

        with tf.variable_scope('target'), \
                collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
            targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA

        if self.method != 'Double':
            # DQN
            best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
        else:
            # Double-DQN
            sc = tf.get_variable_scope()
            with tf.variable_scope(sc, reuse=True):
                next_predict_value = self._get_DQN_prediction(next_state)
            self.greedy_choice = tf.argmax(next_predict_value, 1)   # N,
            predict_onehot = tf.one_hot(self.greedy_choice, self.num_actions, 1.0, 0.0)
            best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)

        target = reward + (1.0 - tf.cast(isOver, tf.float32)) * self.gamma * tf.stop_gradient(best_v)

        self.cost = tf.reduce_mean(symbf.huber_loss(
            target - pred_action_value), name='cost')
        summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                  ('fc.*/W', ['histogram', 'rms']))   # monitor all W
        summary.add_moving_summary(self.cost)

    def _get_optimizer(self):
        lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
        opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
        return optimizer.apply_grad_processors(
            opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])

    @staticmethod
    def update_target_param():
        vars = tf.global_variables()
        ops = []
        G = tf.get_default_graph()
        for v in vars:
            target_name = v.op.name
            if target_name.startswith('target'):
                new_name = target_name.replace('target/', '')
                logger.info("{} <- {}".format(target_name, new_name))
                ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
        return tf.group(*ops, name='update_target_network')
examples/DeepQNetwork/common.py
@@ -14,9 +14,6 @@ from tensorpack import *
 from tensorpack.utils.concurrency import *
 from tensorpack.utils.stats import *
 
-global get_player
-get_player = None
-
 
 def play_one_episode(player, func, verbose=False):
     def f(s):
@@ -30,15 +27,14 @@ def play_one_episode(player, func, verbose=False):
     return np.mean(player.play_one_episode(f))
 
 
-def play_model(cfg):
-    player = get_player(viz=0.01)
+def play_model(cfg, player):
     predfunc = OfflinePredictor(cfg)
     while True:
         score = play_one_episode(player, predfunc)
         print("Total:", score)
 
 
-def eval_with_funcs(predictors, nr_eval):
+def eval_with_funcs(predictors, nr_eval, get_player_fn):
     class Worker(StoppableThread, ShareSessionThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()
@@ -52,7 +48,7 @@ def eval_with_funcs(predictors, nr_eval):
         def run(self):
             with self.default_sess():
-                player = get_player(train=False)
+                player = get_player_fn(train=False)
                 while not self.stopped():
                     try:
                         score = play_one_episode(player, self.func)
@@ -88,18 +84,19 @@ def eval_with_funcs(predictors, nr_eval):
     return (0, 0)
 
 
-def eval_model_multithread(cfg, nr_eval):
+def eval_model_multithread(cfg, nr_eval, get_player_fn):
     func = OfflinePredictor(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
+    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
 
 
 class Evaluator(Triggerable):
-    def __init__(self, nr_eval, input_names, output_names):
+    def __init__(self, nr_eval, input_names, output_names, get_player_fn):
         self.eval_episode = nr_eval
         self.input_names = input_names
         self.output_names = output_names
+        self.get_player_fn = get_player_fn
 
     def _setup_graph(self):
         NR_PROC = min(multiprocessing.cpu_count() // 2, 20)
@@ -108,7 +105,8 @@ class Evaluator(Triggerable):
     def _trigger(self):
         t = time.time()
-        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
+        mean, max = eval_with_funcs(
+            self.pred_funcs, self.eval_episode, self.get_player_fn)
         t = time.time() - t
         if t > 10 * 60:  # eval takes too long
             self.eval_episode = int(self.eval_episode * 0.94)
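With the module-level get_player global removed from common.py, every helper now receives the player, or a factory for one, from its caller; the edits in train-atari.py and DQN.py above are exactly this wiring. A hedged sketch of the new calling convention follows, where get_player and cfg are stand-ins for the example-specific player factory and its PredictConfig.

# Calling-convention sketch only; `get_player` and `cfg` are placeholders for the
# example-specific factory and PredictConfig, and the episode count is arbitrary.
from common import play_model, eval_model_multithread, Evaluator


def get_player(viz=False, train=False):
    raise NotImplementedError  # each example builds its own Atari player here

# playing: the caller constructs the player itself and hands it to play_model
#   play_model(cfg, get_player(viz=0.01))

# evaluation: the factory is passed through, so every worker thread inside
# eval_with_funcs builds its own player via get_player_fn(train=False)
#   eval_model_multithread(cfg, nr_eval=50, get_player_fn=get_player)

# as a training callback:
#   Evaluator(50, ['state'], ['Qvalue'], get_player)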