Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1555899d
Commit
1555899d
authored
Jun 06, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
more general atari/common
parent
943b1701
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
22 deletions
+15
-22
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+10
-5
examples/Atari2600/common.py
examples/Atari2600/common.py
+5
-14
tensorpack/predict/common.py
tensorpack/predict/common.py
+0
-3
No files found.
examples/Atari2600/DQN.py
View file @
1555899d
...
@@ -171,7 +171,7 @@ def get_config():
...
@@ -171,7 +171,7 @@ def get_config():
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
RunOp
(
lambda
:
M
.
update_target_param
()),
RunOp
(
lambda
:
M
.
update_target_param
()),
dataset_train
,
dataset_train
,
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
),
2
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
'fct/output:0'
),
2
),
]),
]),
# save memory for multiprocess evaluator
# save memory for multiprocess evaluator
session_config
=
get_default_sess_config
(
0.3
),
session_config
=
get_default_sess_config
(
0.3
),
...
@@ -194,10 +194,15 @@ if __name__ == '__main__':
...
@@ -194,10 +194,15 @@ if __name__ == '__main__':
assert
args
.
load
is
not
None
assert
args
.
load
is
not
None
ROM_FILE
=
args
.
rom
ROM_FILE
=
args
.
rom
if
args
.
task
!=
'train'
:
cfg
=
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
output_var_names
=
[
'fct/output:0'
])
if
args
.
task
==
'play'
:
if
args
.
task
==
'play'
:
play_model
(
Model
(),
args
.
load
)
play_model
(
cfg
)
elif
args
.
task
==
'eval'
:
elif
args
.
task
==
'eval'
:
eval_model_multithread
(
Model
(),
args
.
load
,
EVAL_EPISODE
)
eval_model_multithread
(
cfg
,
EVAL_EPISODE
)
else
:
else
:
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
...
...
examples/Atari2600/common.py
View file @
1555899d
...
@@ -28,13 +28,8 @@ def play_one_episode(player, func, verbose=False):
...
@@ -28,13 +28,8 @@ def play_one_episode(player, func, verbose=False):
return
act
return
act
return
np
.
mean
(
player
.
play_one_episode
(
f
))
return
np
.
mean
(
player
.
play_one_episode
(
f
))
def
play_model
(
M
,
model_path
):
def
play_model
(
cfg
):
player
=
get_player
(
viz
=
0.01
)
player
=
get_player
(
viz
=
0.01
)
cfg
=
PredictConfig
(
model
=
M
,
input_data_mapping
=
[
0
],
session_init
=
SaverRestore
(
model_path
),
output_var_names
=
[
'fct/output:0'
])
predfunc
=
get_predict_func
(
cfg
)
predfunc
=
get_predict_func
(
cfg
)
while
True
:
while
True
:
score
=
play_one_episode
(
player
,
predfunc
)
score
=
play_one_episode
(
player
,
predfunc
)
...
@@ -73,25 +68,21 @@ def eval_with_funcs(predict_funcs, nr_eval):
...
@@ -73,25 +68,21 @@ def eval_with_funcs(predict_funcs, nr_eval):
return
(
stat
.
average
,
stat
.
max
)
return
(
stat
.
average
,
stat
.
max
)
return
(
0
,
0
)
return
(
0
,
0
)
def
eval_model_multithread
(
M
,
model_path
,
nr_eval
):
def
eval_model_multithread
(
cfg
,
nr_eval
):
cfg
=
PredictConfig
(
model
=
M
,
input_data_mapping
=
[
0
],
session_init
=
SaverRestore
(
model_path
),
output_var_names
=
[
'fct/output:0'
])
func
=
get_predict_func
(
cfg
)
func
=
get_predict_func
(
cfg
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
mean
,
max
=
eval_with_funcs
([
func
]
*
NR_PROC
,
nr_eval
)
mean
,
max
=
eval_with_funcs
([
func
]
*
NR_PROC
,
nr_eval
)
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
mean
,
max
))
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
mean
,
max
))
class
Evaluator
(
Callback
):
class
Evaluator
(
Callback
):
def
__init__
(
self
,
nr_eval
):
def
__init__
(
self
,
nr_eval
,
output_name
):
self
.
eval_episode
=
nr_eval
self
.
eval_episode
=
nr_eval
self
.
output_name
=
output_name
def
_before_train
(
self
):
def
_before_train
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
[
'state'
],
[
'fct/output'
])]
*
NR_PROC
[
'state'
],
[
self
.
output_name
])]
*
NR_PROC
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
t
=
time
.
time
()
t
=
time
.
time
()
...
...
tensorpack/predict/common.py
View file @
1555899d
...
@@ -84,9 +84,6 @@ def get_predict_func(config):
...
@@ -84,9 +84,6 @@ def get_predict_func(config):
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
def
run_input
(
dp
):
def
run_input
(
dp
):
assert
len
(
input_map
)
==
len
(
dp
),
\
"Graph has {} inputs but dataset only gives {} components!"
.
format
(
len
(
input_map
),
len
(
dp
))
feed
=
dict
(
zip
(
input_map
,
dp
))
feed
=
dict
(
zip
(
input_map
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
run_input
return
run_input
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment