Shashank Suhas / seminar-breakout / Commits

Commit ddf737d7
Authored May 20, 2016 by Yuxin Wu
Parent: 961b0ee4

    better infra for evaluate

Showing 4 changed files with 74 additions and 64 deletions:

    examples/Atari2600/DQN.py             +57  -48
    tensorpack/callbacks/inference.py      +4  -12
    tensorpack/predict.py                  +6   -4
    tensorpack/train/base.py               +7   -0
examples/Atari2600/DQN.py

@@ -5,10 +5,11 @@
 import tensorflow as tf
 import numpy as np
-import os, sys
+import os, sys, re
 import random
 import argparse
 from tqdm import tqdm
+import subprocess
 import multiprocessing
 from collections import deque

@@ -22,7 +23,7 @@ from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
 from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
-from exp_replay import AtariExpReplay
+from tensorpack.dataflow.RL import ExpReplay

 """
 Implement DQN in:

@@ -44,6 +45,8 @@ END_EXPLORATION = 0.1
 INIT_MEMORY_SIZE = 50000
 MEMORY_SIZE = 1e6
+STEP_PER_EPOCH = 10000
+EVAL_EPISODE = 100

 class Model(ModelDesc):

@@ -138,41 +141,47 @@ class ExpReplayController(Callback):
     def _trigger_epoch(self):
         if self.d.exploration > END_EXPLORATION:
             self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
-            logger.info("Exploration: {}".format(self.d.exploration))
+            logger.info("Exploration changed to {}".format(self.d.exploration))

-def play_model(model_path, romfile):
-    player = AtariPlayer(AtariDriver(romfile, viz=0.01), action_repeat=ACTION_REPEAT)
-    global NUM_ACTIONS
-    NUM_ACTIONS = player.driver.get_num_actions()
-    M = Model()
-    cfg = PredictConfig(
-            model=M,
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    predfunc = get_predict_func(cfg)
+def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
     while True:
         s = player.current_state()
-        outputs = predfunc([[s]])
+        outputs = func([[s]])
         action_value = outputs[0][0]
         act = action_value.argmax()
+        if verbose:
+            print action_value, act
         if random.random() < 0.01:
             act = random.choice(range(player.driver.get_num_actions()))
         if len(que) == que.maxlen \
                 and que.count(que[0]) == que.maxlen:
-            act = 1
+            act = 1     # hack, avoid stuck
         que.append(act)
+        if verbose:
+            print(act)
         reward, isOver = player.action(act)
         tot_reward += reward
         if isOver:
-            print("Total:", tot_reward)
-            tot_reward = 0
+            return tot_reward

+def play_model(model_path, romfile):
+    player = AtariPlayer(AtariDriver(romfile, viz=0.01), action_repeat=ACTION_REPEAT)
+    global NUM_ACTIONS
+    NUM_ACTIONS = player.driver.get_num_actions()
+    M = Model()
+    cfg = PredictConfig(
+            model=M,
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    predfunc = get_predict_func(cfg)
+    while True:
+        score = play_one_episode(player, predfunc)
+        print("Total:", score)

 def eval_model_multiprocess(model_path, romfile):
     M = Model()

@@ -192,29 +201,10 @@ def eval_model_multiprocess(model_path, romfile):
                     action_repeat=ACTION_REPEAT)
             global NUM_ACTIONS
             NUM_ACTIONS = player.driver.get_num_actions()
             self._init_runtime()
-            tot_reward = 0
-            que = deque(maxlen=30)
             while True:
-                s = player.current_state()
-                outputs = self.func([[s]])
-                action_value = outputs[0][0]
-                act = action_value.argmax()
-                #print action_value, act
-                if random.random() < 0.01:
-                    act = random.choice(range(player.driver.get_num_actions()))
-                if len(que) == que.maxlen \
-                        and que.count(que[0]) == que.maxlen:
-                    act = 1
-                que.append(act)
-                #print(act)
-                reward, isOver = player.action(act)
-                tot_reward += reward
-                if isOver:
-                    self.outq.put(tot_reward)
-                    tot_reward = 0
+                score = play_one_episode(player, self.func)
+                self.outq.put(score)

     NR_PROC = min(multiprocessing.cpu_count() // 2, 10)
     procs = []

@@ -226,13 +216,19 @@ def eval_model_multiprocess(model_path, romfile):
         k.start()
     stat = StatCounter()
-    EVAL_EPISODE = 50
-    for _ in tqdm(range(EVAL_EPISODE)):
-        r = q.get()
-        stat.feed(r)
-    logger.info("Average Score: {}. Max Score: {}".format(
-        stat.average, stat.max))
-    for p in procs:
-        p.terminate()
-        p.join()
+    try:
+        for _ in tqdm(range(EVAL_EPISODE)):
+            r = q.get()
+            stat.feed(r)
+    finally:
+        for p in procs:
+            p.terminate()
+            p.join()
+        if stat.count() > 0:
+            logger.info("Average Score: {}; Max Score: {}".format(
+                stat.average, stat.max))
+            return (stat.average, stat.max)
+        else:
+            return (0, 0)

 def get_config(romfile):

@@ -260,6 +256,18 @@ def get_config(romfile):
     lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

+    class Evaluator(Callback):
+        def _trigger_epoch(self):
+            logger.info("Evaluating...")
+            output = subprocess.check_output(
+                """{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
+                    sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')),
+                shell=True)
+            output = output.strip()
+            output = output[output.find(']')+1:]
+            mean, maximum = re.findall('[0-9\.]+', output)
+            self.trainer.write_scalar_summary('mean_score', mean)
+            self.trainer.write_scalar_summary('max_score', maximum)

     return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),

@@ -269,11 +277,12 @@ def get_config(romfile):
             HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
-            ExpReplayController(dataset_train)
+            ExpReplayController(dataset_train),
+            PeriodicCallback(Evaluator(), 1),
         ]),
         session_config=get_default_sess_config(0.5),
         model=M,
-        step_per_epoch=10000,
+        step_per_epoch=STEP_PER_EPOCH,
         max_epoch=10000,
     )
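Not part of the commit, but a sketch of how the refactored entry points in DQN.py are meant to be driven. The Evaluator callback above shells out with "--task eval --rom ... --load ...", so the script clearly supports those flags; the attribute names and the dispatch below are hypothetical illustration only:

    # hypothetical dispatch at the bottom of DQN.py; paths and attribute names are placeholders
    if args.task == 'play':
        # visualize games one at a time; play_model reuses play_one_episode
        play_model(args.load, args.rom)
    elif args.task == 'eval':
        # run EVAL_EPISODE episodes across worker processes and log the average score
        eval_model_multiprocess(args.load, args.rom)

The periodic Evaluator added in get_config simply re-invokes this same 'eval' task every epoch and scrapes the "Average Score" line from its log output.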
tensorpack/callbacks/inference.py

@@ -59,12 +59,6 @@ class Inferencer(object):
     def _get_output_tensors(self):
         pass

-    def _scalar_summary(self, name, val):
-        self.trainer.summary_writer.add_summary(
-            create_summary(name, val), get_global_step())
-        self.trainer.stat_holder.add_stat(name, val)

 class InferenceRunner(Callback):
     """
     A callback that runs different kinds of inferencer.

@@ -161,9 +155,7 @@ class ScalarStats(Inferencer):
         for stat, name in zip(self.stats, self.names):
             opname, _ = get_op_var_name(name)
             name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
-            self.trainer.summary_writer.add_summary(
-                create_summary(name, stat), get_global_step())
-            self.trainer.stat_holder.add_stat(name, stat)
+            self.trainer.write_scalar_summary(name, stat)

 class ClassificationError(Inferencer):
     """

@@ -197,7 +189,7 @@ class ClassificationError(Inferencer):
         self.err_stat.feed(wrong, batch_size)

     def _after_inference(self):
-        self._scalar_summary(self.summary_name, self.err_stat.accuracy)
+        self.trainer.write_scalar_summary(self.summary_name, self.err_stat.accuracy)

 class BinaryClassificationStats(Inferencer):

@@ -221,5 +213,5 @@ class BinaryClassificationStats(Inferencer):
         self.stat.feed(pred, label)

     def _after_inference(self):
-        self._scalar_summary(self.prefix + '_precision', self.stat.precision)
-        self._scalar_summary(self.prefix + '_recall', self.stat.recall)
+        self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision)
+        self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall)
tensorpack/predict.py

@@ -58,8 +58,7 @@ class PredictConfig(object):
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        self.session_config = kwargs.pop('session_config', get_default_sess_config())
-        assert_type(self.session_config, tf.ConfigProto)
+        self.session_config = kwargs.pop('session_config', None)
         self.session_init = kwargs.pop('session_init')
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)

@@ -87,7 +86,10 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]

-    sess = tf.Session(config=config.session_config)
+    if config.session_config:
+        sess = tf.Session(config=config.session_config)
+    else:
+        sess = tf.Session()
     config.session_init.init(sess)

     def run_input(dp):

@@ -116,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
             os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
         else:
             logger.info("Worker {} uses CPU".format(self.idx))
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
         G = tf.Graph()     # build a graph for each process, because they don't need to share anything
         with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
             if self.idx != 0:
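For illustration only (this mirrors the PredictConfig call already present in DQN.py above, with a placeholder checkpoint path): after this change session_config is optional, so a predictor can be built without constructing a tf.ConfigProto, and get_predict_func falls back to a plain tf.Session():

    cfg = PredictConfig(
        model=Model(),
        input_data_mapping=[0],
        session_init=SaverRestore('train_log/checkpoint'),   # placeholder path
        output_var_names=['fct/output:0'])
    predfunc = get_predict_func(cfg)    # no session_config given, so tf.Session() is used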
tensorpack/train/base.py

@@ -12,6 +12,7 @@ from .config import TrainConfig
 from ..utils import *
 from ..callbacks import StatHolder
 from ..tfutils import *
+from ..tfutils.summary import create_summary
 from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']

@@ -76,6 +77,12 @@ class Trainer(object):
                 self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

+    def write_scalar_summary(self, name, val):
+        self.summary_writer.add_summary(
+            create_summary(name, val), get_global_step())
+        self.stat_holder.add_stat(name, val)
+
     def main_loop(self):
         # some final operations that might modify the graph
         logger.info("Preparing for training...")
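A minimal usage sketch (an assumption for illustration, not shown in the diff): with write_scalar_summary on Trainer, any callback or inferencer that holds a trainer reference can record a scalar to both the summary writer and the stat holder through one call, which is what the inference.py and DQN.py changes above switch to:

    # inside any Callback/Inferencer that has self.trainer set
    self.trainer.write_scalar_summary('mean_score', 123.4)   # the value here is a placeholder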