Shashank Suhas / seminar-breakout / Commits

Commit ddf737d7, authored May 20, 2016 by Yuxin Wu (parent 961b0ee4)

    better infra for evaluate

Showing 4 changed files with 74 additions and 64 deletions (+74 / -64):

    examples/Atari2600/DQN.py            +57 / -48
    tensorpack/callbacks/inference.py     +4 / -12
    tensorpack/predict.py                 +6 /  -4
    tensorpack/train/base.py              +7 /  -0
examples/Atari2600/DQN.py
@@ -5,10 +5,11 @@
 import tensorflow as tf
 import numpy as np
-import os, sys
+import os, sys, re
 import random
 import argparse
 from tqdm import tqdm
+import subprocess
 import multiprocessing
 from collections import deque
@@ -22,7 +23,7 @@ from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
 from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
-from exp_replay import AtariExpReplay
+from tensorpack.dataflow.RL import ExpReplay
 """
 Implement DQN in:
@@ -44,6 +45,8 @@ END_EXPLORATION = 0.1
 INIT_MEMORY_SIZE = 50000
 MEMORY_SIZE = 1e6
+STEP_PER_EPOCH = 10000
+EVAL_EPISODE = 100

 class Model(ModelDesc):
@@ -138,41 +141,47 @@ class ExpReplayController(Callback):
     def _trigger_epoch(self):
         if self.d.exploration > END_EXPLORATION:
             self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
-            logger.info("Exploration: {}".format(self.d.exploration))
+            logger.info("Exploration changed to {}".format(self.d.exploration))

-def play_model(model_path, romfile):
-    player = AtariPlayer(AtariDriver(romfile, viz=0.01), action_repeat=ACTION_REPEAT)
-    global NUM_ACTIONS
-    NUM_ACTIONS = player.driver.get_num_actions()
-    M = Model()
-    cfg = PredictConfig(
-            model=M,
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    predfunc = get_predict_func(cfg)
+def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
     while True:
         s = player.current_state()
-        outputs = predfunc([[s]])
+        outputs = func([[s]])
         action_value = outputs[0][0]
         act = action_value.argmax()
-        print action_value, act
+        if verbose:
+            print action_value, act
         if random.random() < 0.01:
             act = random.choice(range(player.driver.get_num_actions()))
         if len(que) == que.maxlen \
                 and que.count(que[0]) == que.maxlen:
-            act = 1
+            act = 1     # hack, avoid stuck
         que.append(act)
-        print(act)
+        if verbose:
+            print(act)
         reward, isOver = player.action(act)
         tot_reward += reward
         if isOver:
-            print("Total:", tot_reward)
-            tot_reward = 0
+            return tot_reward

+def play_model(model_path, romfile):
+    player = AtariPlayer(AtariDriver(romfile, viz=0.01), action_repeat=ACTION_REPEAT)
+    global NUM_ACTIONS
+    NUM_ACTIONS = player.driver.get_num_actions()
+    M = Model()
+    cfg = PredictConfig(
+            model=M,
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    predfunc = get_predict_func(cfg)
+    while True:
+        score = play_one_episode(player, predfunc)
+        print("Total:", score)

 def eval_model_multiprocess(model_path, romfile):
     M = Model()
@@ -192,29 +201,10 @@ def eval_model_multiprocess(model_path, romfile):
                     action_repeat=ACTION_REPEAT)
             global NUM_ACTIONS
             NUM_ACTIONS = player.driver.get_num_actions()
             self._init_runtime()
-            tot_reward = 0
-            que = deque(maxlen=30)
             while True:
-                s = player.current_state()
-                outputs = self.func([[s]])
-                action_value = outputs[0][0]
-                act = action_value.argmax()
-                #print action_value, act
-                if random.random() < 0.01:
-                    act = random.choice(range(player.driver.get_num_actions()))
-                if len(que) == que.maxlen \
-                        and que.count(que[0]) == que.maxlen:
-                    act = 1
-                que.append(act)
-                #print(act)
-                reward, isOver = player.action(act)
-                tot_reward += reward
-                if isOver:
-                    self.outq.put(tot_reward)
-                    tot_reward = 0
+                score = play_one_episode(player, self.func)
+                self.outq.put(score)

     NR_PROC = min(multiprocessing.cpu_count() // 2, 10)
     procs = []
@@ -226,13 +216,19 @@ def eval_model_multiprocess(model_path, romfile):
         k.start()
     stat = StatCounter()
     try:
-        EVAL_EPISODE = 50
         for _ in tqdm(range(EVAL_EPISODE)):
             r = q.get()
             stat.feed(r)
     finally:
-        logger.info("Average Score: {}. Max Score: {}".format(
-            stat.average, stat.max))
+        for p in procs:
+            p.terminate()
+            p.join()
+        if stat.count() > 0:
+            logger.info("Average Score: {}; Max Score: {}".format(
+                stat.average, stat.max))
+            return (stat.average, stat.max)
+        else:
+            return (0, 0)

 def get_config(romfile):
@@ -260,6 +256,18 @@ def get_config(romfile):
     lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

+    class Evaluator(Callback):
+        def _trigger_epoch(self):
+            logger.info("Evaluating...")
+            output = subprocess.check_output(
+                """{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
+                    sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')),
+                shell=True)
+            output = output.strip()
+            output = output[output.find(']')+1:]
+            mean, maximum = re.findall('[0-9\.]+', output)
+            self.trainer.write_scalar_summary('mean_score', mean)
+            self.trainer.write_scalar_summary('max_score', maximum)
+
     return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
@@ -269,11 +277,12 @@ def get_config(romfile):
             HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
-            ExpReplayController(dataset_train)
+            ExpReplayController(dataset_train),
+            PeriodicCallback(Evaluator(), 1),
         ]),
         session_config=get_default_sess_config(0.5),
         model=M,
-        step_per_epoch=10000,
+        step_per_epoch=STEP_PER_EPOCH,
         max_epoch=10000,
     )
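The change above splits the old monolithic play_model into play_one_episode (which now returns the episode reward instead of printing it) plus a thin play_model wrapper, so the same rollout loop also serves the multiprocess evaluator and the new Evaluator callback. A minimal sketch of how a caller could reuse it after this commit; evaluate_n_episodes is a hypothetical helper, and the surrounding API is assumed to match the 2016-era tensorpack code shown in the diff:

    # Sketch only: assumes the API visible in the diff above; evaluate_n_episodes is hypothetical.
    def evaluate_n_episodes(model_path, romfile, n=5):
        player = AtariPlayer(AtariDriver(romfile), action_repeat=ACTION_REPEAT)
        global NUM_ACTIONS
        NUM_ACTIONS = player.driver.get_num_actions()
        cfg = PredictConfig(
                model=Model(),
                input_data_mapping=[0],
                session_init=SaverRestore(model_path),
                output_var_names=['fct/output:0'])
        predfunc = get_predict_func(cfg)
        # play_one_episode now returns the total reward, so the caller decides
        # whether to print it, queue it, or aggregate it.
        return [play_one_episode(player, predfunc) for _ in range(n)]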
tensorpack/callbacks/inference.py
@@ -59,12 +59,6 @@ class Inferencer(object):
     def _get_output_tensors(self):
         pass

-    def _scalar_summary(self, name, val):
-        self.trainer.summary_writer.add_summary(
-            create_summary(name, val), get_global_step())
-        self.trainer.stat_holder.add_stat(name, val)
-
 class InferenceRunner(Callback):
     """
     A callback that runs different kinds of inferencer.
@@ -161,9 +155,7 @@ class ScalarStats(Inferencer):
         for stat, name in zip(self.stats, self.names):
             opname, _ = get_op_var_name(name)
             name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
-            self.trainer.summary_writer.add_summary(
-                create_summary(name, stat), get_global_step())
-            self.trainer.stat_holder.add_stat(name, stat)
+            self.trainer.write_scalar_summary(name, stat)

 class ClassificationError(Inferencer):
     """
@@ -197,7 +189,7 @@ class ClassificationError(Inferencer):
         self.err_stat.feed(wrong, batch_size)

     def _after_inference(self):
-        self._scalar_summary(self.summary_name, self.err_stat.accuracy)
+        self.trainer.write_scalar_summary(self.summary_name, self.err_stat.accuracy)

 class BinaryClassificationStats(Inferencer):
@@ -221,5 +213,5 @@ class BinaryClassificationStats(Inferencer):
         self.stat.feed(pred, label)

     def _after_inference(self):
-        self._scalar_summary(self.prefix + '_precision', self.stat.precision)
-        self._scalar_summary(self.prefix + '_recall', self.stat.recall)
+        self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision)
+        self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall)
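With the private _scalar_summary helper removed, every Inferencer now reports scores through the trainer's new write_scalar_summary, which handles both the TF summary and the StatHolder entry. A hedged sketch of what a custom Inferencer's _after_inference could look like after this commit; MeanScore and score_stat are invented names for illustration:

    # Illustrative only; MeanScore and score_stat are not part of tensorpack.
    class MeanScore(Inferencer):
        def _after_inference(self):
            # one call now writes the summary and feeds the trainer's StatHolder
            self.trainer.write_scalar_summary('mean_score', self.score_stat.average)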
tensorpack/predict.py
@@ -58,8 +58,7 @@ class PredictConfig(object):
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
())
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
None
)
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
@@ -87,7 +86,10 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]

-    sess = tf.Session(config=config.session_config)
+    if config.session_config:
+        sess = tf.Session(config=config.session_config)
+    else:
+        sess = tf.Session()
     config.session_init.init(sess)

     def run_input(dp):
@@ -116,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
             os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
         else:
             logger.info("Worker {} uses CPU".format(self.idx))
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
         G = tf.Graph()   # build a graph for each process, because they don't need to share anything
         with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
             if self.idx != 0:
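session_config is now optional in PredictConfig: when it is left as None, get_predict_func falls back to a plain tf.Session(). A hedged sketch of both call styles, assuming the keyword arguments shown in the diff (the checkpoint path is a placeholder):

    # Sketch; keyword names taken from the diff above, checkpoint path is a placeholder.
    cfg_default = PredictConfig(
            model=Model(),
            session_init=SaverRestore('train_log/checkpoint'),
            output_var_names=['fct/output:0'])            # no session_config -> plain tf.Session()

    cfg_tuned = PredictConfig(
            model=Model(),
            session_init=SaverRestore('train_log/checkpoint'),
            session_config=get_default_sess_config(0.5),  # an explicit ConfigProto still works
            output_var_names=['fct/output:0'])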
tensorpack/train/base.py
@@ -12,6 +12,7 @@ from .config import TrainConfig
 from ..utils import *
 from ..callbacks import StatHolder
 from ..tfutils import *
+from ..tfutils.summary import create_summary
 from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']
@@ -76,6 +77,12 @@ class Trainer(object):
                 self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

+    def write_scalar_summary(self, name, val):
+        self.summary_writer.add_summary(
+            create_summary(name, val), get_global_step())
+        self.stat_holder.add_stat(name, val)
+
     def main_loop(self):
         # some final operations that might modify the graph
         logger.info("Preparing for training...")
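Trainer.write_scalar_summary gives callbacks and inferencers a single entry point that writes the event-file summary and updates the StatHolder; the Evaluator callback in DQN.py and the Inferencers above both go through it. A minimal hedged sketch of a callback using it; EpisodeLengthLogger and its StatCounter argument are hypothetical:

    # Sketch; EpisodeLengthLogger is invented for illustration.
    class EpisodeLengthLogger(Callback):
        def __init__(self, stat_counter):
            self._stat = stat_counter   # assumed to be a StatCounter filled during training
        def _trigger_epoch(self):
            # one call writes the TF summary and feeds the trainer's StatHolder
            self.trainer.write_scalar_summary('episode_length', self._stat.average)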