Shashank Suhas / seminar-breakout · Commits
Commit da9b1b2f, authored May 31, 2016 by Yuxin Wu
eval with multithread

parent b61ba3c9
Changes: 4 changed files, with 67 additions and 51 deletions (+67 / -51)

  examples/Atari2600/DQN.py       +62  -48
  tensorpack/RL/atari.py           +2   -0
  tensorpack/callbacks/group.py    +1   -1
  tensorpack/predict/common.py     +2   -2
examples/Atari2600/DQN.py

@@ -5,18 +5,22 @@
 import numpy as np
 import tensorflow as tf
-import os, sys, re
+import os, sys, re, time
 import random
 import argparse
-from tqdm import tqdm
 import subprocess
-import multiprocessing
+import multiprocessing, threading
 from collections import deque
+from six.moves import queue
+from tqdm import tqdm

 from tensorpack import *
 from tensorpack.models import *
 from tensorpack.utils import *
-from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call
+from tensorpack.utils.concurrency import (ensure_proc_terminate, \
+    subproc_call, StoppableThread)
 from tensorpack.utils.stat import *
 from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
 from tensorpack.tfutils import symbolic_functions as symbf
@@ -33,7 +37,7 @@ for atari games
 BATCH_SIZE = 32
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
-ACTION_REPEAT = 4
+ACTION_REPEAT = 3
 HEIGHT_RANGE = (36, 204)    # for breakout
 CHANNEL = FRAME_HISTORY
 IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
@@ -64,7 +68,7 @@ class Model(ModelDesc):
     def _get_input_vars(self):
         assert NUM_ACTIONS is not None
         return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
-                InputVar(tf.int32, (None,), 'action'),
+                InputVar(tf.int64, (None,), 'action'),
                 InputVar(tf.float32, (None,), 'reward'),
                 InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
                 InputVar(tf.bool, (None,), 'isOver') ]
@@ -72,7 +76,7 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image, is_training):
         """ image: [0,255]"""
         image = image / 255.0
-        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
+        with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
             l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
             l = MaxPooling('pool0', l, 2)
             l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
@@ -80,8 +84,11 @@ class Model(ModelDesc):
             l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
             l = MaxPooling('pool2', l, 2)
             l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
+            #l = MaxPooling('pool3', l, 2)
+            #l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)
+            # the original arch
+            #l = Conv2D('conv0', image, out_channel=32, kernel_shape=8, stride=4)
+            #l = Conv2D('conv1', l, out_channel=64, kernel_shape=4, stride=2)
+            #l = Conv2D('conv2', l, out_channel=64, kernel_shape=3)
         l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
         l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
@@ -101,11 +108,11 @@ class Model(ModelDesc):
         targetQ_predict_value = self._get_DQN_prediction(next_state, False)    # NxA

         # DQN
-        best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
+        #best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,

         # Double-DQN
-        #predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
-        #best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
+        predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
+        best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)

         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
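Note on the hunk above: it flips the target computation from vanilla DQN to Double-DQN. Instead of taking the maximum of the target network's own Q-values, the action is chosen by the online network (self.greedy_choice, defined outside this hunk) and only evaluated with the target network. A minimal numpy sketch of the difference, with made-up values (q_online and q_target are illustrative names, not variables from the commit):

    import numpy as np

    q_online = np.array([[1.0, 2.5, 2.4]])   # online net Q-values for next state, shape (N, A)
    q_target = np.array([[0.9, 2.0, 2.2]])   # target net Q-values for next state, shape (N, A)

    # vanilla DQN target value: max over the target network's own estimates
    best_v_dqn = q_target.max(axis=1)                                  # [2.2]

    # Double-DQN: pick the action with the online net, evaluate it with the target net
    greedy_choice = q_online.argmax(axis=1)                            # [1]
    best_v_double = q_target[np.arange(len(q_target)), greedy_choice]  # [2.0]

The tf.one_hot / tf.reduce_sum pair in the diff is the graph-mode equivalent of that last indexing step.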
@@ -156,7 +163,7 @@ def play_one_episode(player, func, verbose=False):
     return sc

 def play_model(model_path):
-    player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.01), FRAME_HISTORY), 30, 1)
+    player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.013), FRAME_HISTORY), 30, 1)
     cfg = PredictConfig(
             model=Model(),
             input_data_mapping=[0],
@@ -167,54 +174,61 @@ def play_model(model_path):
     score = play_one_episode(player, predfunc)
     print("Total:", score)

-def eval_model_multiprocess(model_path):
-    cfg = PredictConfig(
-            model=Model(),
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-
-    class Worker(MultiProcessPredictWorker):
-        def __init__(self, idx, gpuid, config, outqueue):
-            super(Worker, self).__init__(idx, gpuid, config)
-            self.outq = outqueue
-        def run(self):
-            player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
-            self._init_runtime()
-            while True:
-                score = play_one_episode(player, self.func)
-                self.outq.put(score)
-
-    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    q = multiprocessing.Queue()
-    gpuid = get_gpus()[0]
-    procs = [Worker(k, gpuid, cfg, q) for k in range(NR_PROC)]
-    ensure_proc_terminate(procs)
-    for k in procs:
-        k.start()
-    stat = StatCounter()
-    try:
-        for _ in tqdm(range(EVAL_EPISODE)):
-            r = q.get()
-            stat.feed(r)
-    finally:
-        logger.info("Average Score: {}; Max Score: {}".format(
-            stat.average, stat.max))
+def eval_with_funcs(predict_funcs):
+    class Worker(StoppableThread):
+        def __init__(self, func, queue):
+            super(Worker, self).__init__()
+            self.func = func
+            self.q = queue
+        def run(self):
+            player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
+            while not self.stopped():
+                score = play_one_episode(player, self.func)
+                while not self.stopped():
+                    try:
+                        self.q.put(score, timeout=5)
+                        break
+                    except queue.Queue.Full:
+                        pass
+
+    q = queue.Queue()
+    threads = [Worker(f, q) for f in predict_funcs]
+    for k in threads:
+        k.start()
+        time.sleep(0.1)    # avoid simulator bugs
+    stat = StatCounter()
+    try:
+        for _ in tqdm(range(EVAL_EPISODE)):
+            r = q.get()
+            stat.feed(r)
+        for k in threads:
+            k.stop()
+        for k in threads:
+            k.join()
+    finally:
+        return (stat.average, stat.max)
+
+def eval_model_multithread(model_path):
+    cfg = PredictConfig(
+            model=Model(),
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    p = get_player(); del p    # set NUM_ACTIONS
+    func = get_predict_func(cfg)
+    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
+    mean, max = eval_with_funcs([func] * NR_PROC)
+    logger.info("Average Score: {}; Max Score: {}".format(mean, max))

 class Evaluator(Callback):
-    def _trigger_epoch(self):
-        output = subproc_call("{} --task eval --rom {} --load {}".format(
-            sys.argv[0], ROM_FILE, os.path.join(logger.LOG_DIR, 'checkpoint')),
-            timeout=10*60)
-        if output:
-            last = output.strip().split('\n')[-1]
-            last = last[last.find(']')+1:]
-            mean, maximum = re.findall('[0-9\.\-]+', last)[-2:]
-            self.trainer.write_scalar_summary('mean_score', mean)
-            self.trainer.write_scalar_summary('max_score', maximum)
+    def _before_train(self):
+        NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
+        self.pred_funcs = [self.trainer.get_predict_func(
+            ['state'], ['fct/output'])] * NR_PROC
+
+    def _trigger_epoch(self):
+        logger.info("Evaluating...")
+        mean, max = eval_with_funcs(self.pred_funcs)
+        self.trainer.write_scalar_summary('mean_score', mean)
+        self.trainer.write_scalar_summary('max_score', max)

 def get_config():
     basename = os.path.basename(__file__)
@@ -277,7 +291,7 @@ if __name__ == '__main__':
         play_model(args.load)
         sys.exit()
     if args.task == 'eval':
-        eval_model_multiprocess(args.load)
+        eval_model_multithread(args.load)
         sys.exit()

     with tf.Graph().as_default():
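The core of this commit is the switch from one predictor process per worker (eval_model_multiprocess) to one evaluation thread per predict function (eval_with_funcs). The pattern: each worker thread plays episodes in a loop and pushes scores into a shared queue, using a timed put so it can notice the stop flag instead of blocking forever; the main thread drains EVAL_EPISODE scores, then stops and joins the workers. A minimal self-contained sketch of that pattern using only the standard library (the commit relies on tensorpack's StoppableThread; threading.Event and a dummy play_one_episode stand in here):

    import queue
    import random
    import threading

    EVAL_EPISODE = 20

    def play_one_episode():
        # stand-in for running the predict function through a full Atari episode
        return random.uniform(0, 400)

    def worker(q, stop):
        while not stop.is_set():
            score = play_one_episode()
            # timed put: re-check the stop flag instead of blocking forever on a full queue
            while not stop.is_set():
                try:
                    q.put(score, timeout=5)
                    break
                except queue.Full:    # module-level queue.Full exception
                    pass

    q = queue.Queue(maxsize=8)
    stop = threading.Event()
    threads = [threading.Thread(target=worker, args=(q, stop)) for _ in range(4)]
    for t in threads:
        t.start()

    scores = [q.get() for _ in range(EVAL_EPISODE)]
    stop.set()
    for t in threads:
        t.join()
    print("Average Score: {}; Max Score: {}".format(sum(scores) / len(scores), max(scores)))

Note that eval_model_multithread in the diff builds a single predict function and hands the same one to every thread ([func] * NR_PROC), so all workers share one TensorFlow session within the process.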
tensorpack/RL/atari.py

@@ -42,6 +42,8 @@ class AtariPlayer(RLEnvironment):
         self.rng = get_rng(self)
         self.ale.setInt("random_seed", self.rng.randint(0, 10000))
+        self.ale.setBool("showinfo", False)
+        #ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
         self.ale.setInt("frame_skip", 1)
         self.ale.setBool('color_averaging', False)
         # manual.pdf suggests otherwise. may need to check
tensorpack/callbacks/group.py

@@ -89,7 +89,7 @@ class CallbackTimeLogger(object):
         msgs = []
         for name, t in self.times:
             if t / self.tot > 0.3 and t > 1:
-                msgs.append("{}:{:.3f}sec".format(name, t))
+                msgs.append("{}: {:.3f}sec".format(name, t))
         logger.info(
             "Callbacks took {:.3f} sec in total. {}".format(
                 self.tot, '; '.join(msgs)))
tensorpack/predict/common.py

@@ -79,8 +79,8 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]

-    # start with minimal memory, but allow growth
-    sess = tf.Session(config=get_default_sess_config(0.01))
+    # XXX does it work? start with minimal memory, but allow growth
+    sess = tf.Session(config=get_default_sess_config(0.3))
     config.session_init.init(sess)

     def run_input(dp):
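For context on the last hunk: judging by the surrounding comment, get_default_sess_config presumably builds a session config that reserves a fraction of GPU memory up front while allowing it to grow on demand, and the commit raises that fraction from 0.01 to 0.3 for predictors. A hedged sketch of such a config in plain TensorFlow 1.x (default_sess_config below is a hypothetical stand-in, not tensorpack's implementation):

    import tensorflow as tf

    def default_sess_config(mem_fraction=0.3):
        # assumed behaviour: claim only a slice of GPU memory, let it grow as needed
        conf = tf.ConfigProto()
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        conf.gpu_options.allow_growth = True
        return conf

    sess = tf.Session(config=default_sess_config(0.3))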