Commit 943b1701
authored Jun 06, 2016 by Yuxin Wu

separate common from DQN

parent 76cbc245

Showing 7 changed files with 180 additions and 131 deletions.
examples/Atari2600/DQN.py        +14   -101
examples/Atari2600/common.py     +103  -0
examples/cifar-convnet.py        +0    -1
examples/svhn-digit-convnet.py   +0    -1
tensorpack/RL/atari.py           +6    -8
tensorpack/RL/envbase.py         +34   -4
tensorpack/RL/expreplay.py       +23   -16
examples/Atari2600/DQN.py

@@ -13,20 +13,16 @@ import subprocess
 import multiprocessing, threading
 from collections import deque
-from six.moves import queue
-from tqdm import tqdm

 from tensorpack import *
 from tensorpack.models import *
 from tensorpack.utils import *
-from tensorpack.utils.concurrency import (ensure_proc_terminate, \
-        subproc_call, StoppableThread)
-from tensorpack.utils.stat import *
-from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
+from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *

 from tensorpack.RL import *
+import common
+from common import play_model, Evaluator, eval_model_multithread

 BATCH_SIZE = 32
 IMAGE_SIZE = (84, 84)
@@ -56,16 +52,18 @@ def get_player(viz=False, train=False):
             frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
             live_lost_as_eoe=train)
     global NUM_ACTIONS
-    NUM_ACTIONS = pl.get_num_actions()
+    NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 20000)
     return pl
+common.get_player = get_player  # so that eval functions in common can use the player

 class Model(ModelDesc):
     def _get_input_vars(self):
-        assert NUM_ACTIONS is not None
+        if NUM_ACTIONS is None:
+            p = get_player(); del p
         return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
                 InputVar(tf.int64, (None,), 'action'),
                 InputVar(tf.float32, (None,), 'reward'),
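Note on the hunk above: the hard `assert NUM_ACTIONS is not None` is replaced by lazily building a throwaway player, whose constructor side effect fills in the module-level constant. A minimal sketch of that pattern, with illustrative stand-in names rather than tensorpack API:

# Lazy module-level initialization via a factory's side effect (sketch only).
NUM_ACTIONS = None

def get_player():
    global NUM_ACTIONS
    player = object()      # stand-in for the real AtariPlayer pipeline
    NUM_ACTIONS = 4        # in DQN.py: pl.get_action_space().num_actions()
    return player

def get_input_vars():
    if NUM_ACTIONS is None:        # same guard as the new _get_input_vars
        p = get_player(); del p    # build one player only for its side effect
    return NUM_ACTIONS

print(get_input_vars())  # 4, even though no player object is kept around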
@@ -141,85 +139,6 @@ class Model(ModelDesc):
     def predictor(self, state):
         return self.predict_value.eval(feed_dict={'state:0': [state]})[0]

-def play_one_episode(player, func, verbose=False):
-    def f(s):
-        act = func([[s]])[0][0].argmax()
-        if random.random() < 0.01:
-            act = random.choice(range(NUM_ACTIONS))
-        if verbose:
-            print(act)
-        return act
-    return np.mean(player.play_one_episode(f))
-
-def play_model(model_path):
-    player = get_player(viz=0.01)
-    cfg = PredictConfig(
-            model=Model(),
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    predfunc = get_predict_func(cfg)
-    while True:
-        score = play_one_episode(player, predfunc)
-        print("Total:", score)
-
-def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
-    class Worker(StoppableThread):
-        def __init__(self, func, queue):
-            super(Worker, self).__init__()
-            self.func = func
-            self.q = queue
-        def run(self):
-            player = get_player()
-            while not self.stopped():
-                score = play_one_episode(player, self.func)
-                self.queue_put_stoppable(self.q, score)
-
-    q = queue.Queue(maxsize=2)
-    threads = [Worker(f, q) for f in predict_funcs]
-    for k in threads:
-        k.start()
-        time.sleep(0.1)  # avoid simulator bugs
-    stat = StatCounter()
-    try:
-        for _ in tqdm(range(nr_eval)):
-            r = q.get()
-            stat.feed(r)
-    finally:
-        logger.info("Waiting for all the workers to finish the last run...")
-        for k in threads:
-            k.stop()
-        for k in threads:
-            k.join()
-    return (stat.average, stat.max)
-
-def eval_model_multithread(model_path):
-    cfg = PredictConfig(
-            model=Model(),
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    p = get_player(); del p    # set NUM_ACTIONS
-    func = get_predict_func(cfg)
-    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC)
-    logger.info("Average Score: {}; Max Score: {}".format(mean, max))
-
-class Evaluator(Callback):
-    def _before_train(self):
-        NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-        self.pred_funcs = [self.trainer.get_predict_func(
-            ['state'], ['fct/output'])] * NR_PROC
-        self.eval_episode = EVAL_EPISODE
-
-    def _trigger_epoch(self):
-        t = time.time()
-        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
-        t = time.time() - t
-        if t > 8 * 60:  # eval takes too long
-            self.eval_episode = int(self.eval_episode * 0.89)
-        self.trainer.write_scalar_summary('mean_score', mean)
-        self.trainer.write_scalar_summary('max_score', max)
-
 def get_config():
     basename = os.path.basename(__file__)
     logger.set_logger_dir(
@@ -229,10 +148,9 @@ def get_config():
     dataset_train = ExpReplay(
             predictor=M.predictor,
             player=get_player(train=True),
-            num_actions=NUM_ACTIONS,
-            memory_size=MEMORY_SIZE,
             batch_size=BATCH_SIZE,
-            populate_size=INIT_MEMORY_SIZE,
+            memory_size=MEMORY_SIZE,
+            init_memory_size=INIT_MEMORY_SIZE,
             exploration=INIT_EXPLORATION,
             end_exploration=END_EXPLORATION,
             exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
@@ -253,7 +171,7 @@ def get_config():
             HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
             RunOp(lambda: M.update_target_param()),
             dataset_train,
-            PeriodicCallback(Evaluator(), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE), 2),
         ]),
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.3),
@@ -272,20 +190,15 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     if args.task != 'train':
         assert args.load is not None

     ROM_FILE = args.rom

     if args.task == 'play':
-        play_model(args.load)
+        play_model(Model(), args.load)
         sys.exit()
     elif args.task == 'eval':
-        eval_model_multithread(args.load)
+        eval_model_multithread(Model(), args.load, EVAL_EPISODE)
         sys.exit()

     with tf.Graph().as_default():
         config = get_config()
         if args.load:
             config.session_init = SaverRestore(args.load)
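Note: the glue that makes the split work is the single added line `common.get_player = get_player`. common.py only declares the name (`global get_player`) and resolves it at call time, so each game script injects its own player factory. A rough sketch of that injection pattern, using a stand-in module instead of the real files:

import types

# Stand-in for examples/Atari2600/common.py: it calls get_player but never defines it.
common = types.ModuleType("common")

def eval_one(func):
    player = common.get_player()     # looked up at call time, not at import time
    return func(player)
common.eval_one = eval_one

# Stand-in for the DQN.py side: define the game-specific factory and inject it.
def get_player():
    return "breakout-player"
common.get_player = get_player       # same trick as the diff above

print(common.eval_one(lambda p: p.upper()))   # BREAKOUT-PLAYER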
examples/Atari2600/common.py  (new file, 0 → 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import random, time
import threading, multiprocessing
import numpy as np
from tqdm import tqdm
from six.moves import queue

from tensorpack import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import *
from tensorpack.callbacks import *

global get_player

def play_one_episode(player, func, verbose=False):
    # 0.01-greedy evaluation
    def f(s):
        spc = player.get_action_space()
        act = func([[s]])[0][0].argmax()
        if random.random() < 0.01:
            act = spc.sample()
        if verbose:
            print(act)
        return act
    return np.mean(player.play_one_episode(f))

def play_model(M, model_path):
    player = get_player(viz=0.01)
    cfg = PredictConfig(
            model=M,
            input_data_mapping=[0],
            session_init=SaverRestore(model_path),
            output_var_names=['fct/output:0'])
    predfunc = get_predict_func(cfg)
    while True:
        score = play_one_episode(player, predfunc)
        print("Total:", score)

def eval_with_funcs(predict_funcs, nr_eval):
    class Worker(StoppableThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self.func = func
            self.q = queue
        def run(self):
            player = get_player()
            while not self.stopped():
                score = play_one_episode(player, self.func)
                self.queue_put_stoppable(self.q, score)

    q = queue.Queue(maxsize=2)
    threads = [Worker(f, q) for f in predict_funcs]
    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()
    try:
        for _ in tqdm(range(nr_eval)):
            r = q.get()
            stat.feed(r)
    except:
        logger.exception("Eval")
    finally:
        logger.info("Waiting for all the workers to finish the last run...")
        for k in threads:
            k.stop()
        for k in threads:
            k.join()
    if stat.count > 0:
        return (stat.average, stat.max)
    return (0, 0)

def eval_model_multithread(M, model_path, nr_eval):
    cfg = PredictConfig(
            model=M,
            input_data_mapping=[0],
            session_init=SaverRestore(model_path),
            output_var_names=['fct/output:0'])
    func = get_predict_func(cfg)
    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
    logger.info("Average Score: {}; Max Score: {}".format(mean, max))

class Evaluator(Callback):
    def __init__(self, nr_eval):
        self.eval_episode = nr_eval

    def _before_train(self):
        NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
        self.pred_funcs = [self.trainer.get_predict_func(
            ['state'], ['fct/output'])] * NR_PROC

    def _trigger_epoch(self):
        t = time.time()
        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
        t = time.time() - t
        if t > 8 * 60:  # eval takes too long
            self.eval_episode = int(self.eval_episode * 0.89)
        self.trainer.write_scalar_summary('mean_score', mean)
        self.trainer.write_scalar_summary('max_score', max)
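Note: with these helpers in place, a game script only has to define its Model and get_player, inject the latter, and delegate. This is how DQN.py consumes the module after this commit (EVAL_EPISODE, Model, get_player and args all live in DQN.py itself):

import common
from common import play_model, Evaluator, eval_model_multithread

common.get_player = get_player    # let the shared eval code build players

if args.task == 'play':
    play_model(Model(), args.load)
elif args.task == 'eval':
    eval_model_multithread(Model(), args.load, EVAL_EPISODE)
# during training, PeriodicCallback(Evaluator(EVAL_EPISODE), 2) runs the same evaluation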
examples/cifar-convnet.py

@@ -52,7 +52,6 @@ class Model(ModelDesc):
         l = tf.nn.dropout(l, keep_prob)
         l = FullyConnected('fc1', l, 512,
                            b_init=tf.constant_initializer(0.1))
-        # fc will have activation summary by default. disable for the output layer
         logits = FullyConnected('linear', l, out_dim=self.cifar_classnum, nl=tf.identity)
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
examples/svhn-digit-convnet.py

@@ -44,7 +44,6 @@ class Model(ModelDesc):
         l = tf.nn.dropout(l, keep_prob)
         l = FullyConnected('fc0', l, 512,
                            b_init=tf.constant_initializer(0.1))
-        # fc will have activation summary by default. disable for the output layer
         logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
         prob = tf.nn.softmax(logits, name='output')
tensorpack/RL/atari.py

@@ -12,7 +12,7 @@ from six.moves import range
 from ..utils import get_rng, logger, memoized
 from ..utils.stat import StatCounter
-from .envbase import RLEnvironment
+from .envbase import RLEnvironment, DiscreteActionSpace

 try:
     from ale_python_interface import ALEInterface

@@ -104,6 +104,7 @@ class AtariPlayer(RLEnvironment):
                 ret = np.maximum(ret, self.last_raw_screen)
         if self.viz:
             if isinstance(self.viz, float):
+                #m = cv2.resize(ret, (1920,1200))
                 cv2.imshow(self.windowname, ret)
                 time.sleep(self.viz)
         ret = ret[self.height_range[0]:self.height_range[1],:]

@@ -113,11 +114,8 @@ class AtariPlayer(RLEnvironment):
         ret = np.expand_dims(ret, axis=2)
         return ret

-    def get_num_actions(self):
-        """
-        :returns: the number of legal actions
-        """
-        return len(self.actions)
+    def get_action_space(self):
+        return DiscreteActionSpace(len(self.actions))

     def restart_episode(self):
         if self.current_episode_score.count > 0:

@@ -170,7 +168,7 @@ if __name__ == '__main__':
     def benchmark():
         a = AtariPlayer(sys.argv[1], viz=False, height_range=(28,-8))
-        num = a.get_num_actions()
+        num = a.get_action_space().num_actions()
         rng = get_rng(num)
         start = time.time()
         cnt = 0

@@ -194,7 +192,7 @@ if __name__ == '__main__':
     else:
         a = AtariPlayer(sys.argv[1],
                 viz=0.03, height_range=(28,-8))
-        num = a.get_num_actions()
+        num = a.get_action_space().num_actions()
         rng = get_rng(num)
         import time
         while True:
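Note: call sites migrate from asking the player for an action count to asking it for an action-space object, which also owns random sampling. A short sketch of the new call pattern (DiscreteActionSpace(6) stands in for what AtariPlayer.get_action_space() returns, since the real call needs an ALE ROM):

from tensorpack.RL.envbase import DiscreteActionSpace

# before this commit:
#   num = player.get_num_actions()
spc = DiscreteActionSpace(6)   # e.g. the 6 legal Breakout actions
num = spc.num_actions()        # replaces player.get_num_actions()
act = spc.sample()             # uniform random action, used by the 0.01-greedy eval
print(num, act)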
tensorpack/RL/envbase.py

@@ -6,8 +6,11 @@
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
+import random
+from ..utils import get_rng

-__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer']
+__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer', 'DiscreteActionSpace']

 class RLEnvironment(object):
     __meta__ = ABCMeta

@@ -33,6 +36,10 @@ class RLEnvironment(object):
         """ Start a new episode, even if the current hasn't ended """
         raise NotImplementedError()

+    def get_action_space(self):
+        """ return an `ActionSpace` instance"""
+        raise NotImplementedError()
+
     def get_stat(self):
         """
         return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat

@@ -40,7 +47,7 @@ class RLEnvironment(object):
         return {}

     def reset_stat(self):
-        """ reset the statistics counter"""
+        """ reset all statistics counter"""
         self.stats = defaultdict(list)

     def play_one_episode(self, func, stat='score'):

@@ -57,6 +64,28 @@ class RLEnvironment(object):
         self.reset_stat()
         return s

+class ActionSpace(object):
+    def __init__(self):
+        self.rng = get_rng(self)
+
+    @abstractmethod
+    def sample(self):
+        pass
+
+    def num_actions(self):
+        raise NotImplementedError()
+
+class DiscreteActionSpace(ActionSpace):
+    def __init__(self, num):
+        super(DiscreteActionSpace, self).__init__()
+        self.num = num
+
+    def sample(self):
+        return self.rng.randint(self.num)
+
+    def num_actions(self):
+        return self.num
+
 class NaiveRLEnvironment(RLEnvironment):
     """ for testing only"""
     def __init__(self):

@@ -67,8 +96,6 @@ class NaiveRLEnvironment(RLEnvironment):
     def action(self, act):
         self.k = act
         return (self.k, self.k > 10)
-    def restart_episode(self):
-        pass

 class ProxyPlayer(RLEnvironment):
     """ Serve as a proxy another player """

@@ -93,3 +120,6 @@ class ProxyPlayer(RLEnvironment):
     def restart_episode(self):
         self.player.restart_episode()
+
+    def get_action_space(self):
+        return self.player.get_action_space()
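Note: the new base classes are small enough to restate in full. The sketch below mirrors the ActionSpace/DiscreteActionSpace contract without tensorpack imports (get_rng(self) replaced by a plain numpy RandomState); this is all a custom environment has to return from get_action_space():

import numpy as np

class ActionSpace(object):
    # mirror of tensorpack.RL.envbase.ActionSpace, with a plain RandomState
    def __init__(self):
        self.rng = np.random.RandomState()
    def sample(self):
        raise NotImplementedError()
    def num_actions(self):
        raise NotImplementedError()

class DiscreteActionSpace(ActionSpace):
    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num
    def sample(self):
        return self.rng.randint(self.num)   # uniform over [0, num)
    def num_actions(self):
        return self.num

spc = DiscreteActionSpace(4)
print(spc.num_actions(), spc.sample())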
tensorpack/RL/expreplay.py

@@ -25,10 +25,10 @@ class ExpReplay(DataFlow, Callback):
     def __init__(self,
             predictor,
             player,
-            num_actions,
-            memory_size=1e6,
             batch_size=32,
-            populate_size=50000,
+            memory_size=1e6,
+            populate_size=None,  # deprecated
+            init_memory_size=50000,
             exploration=1,
             end_exploration=0.1,
             exploration_epoch_anneal=0.002,

@@ -37,20 +37,27 @@ class ExpReplay(DataFlow, Callback):
             history_len=1):
         """
-        :param predictor: a callabale calling the up-to-date network.
-            called with a state, return a distribution
-        :param player: a `RLEnvironment`
-        :param num_actions: int
+        :param predictor: a callabale running the up-to-date network.
+            called with a state, return a distribution.
+        :param player: an `RLEnvironment`
         :param history_len: length of history frames to concat. zero-filled initial frames
+        :param update_frequency: number of new transitions to add to memory
+            after sampling a batch of transitions for training
         """
+        # XXX back-compat
+        if populate_size is not None:
+            logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
+            init_memory_size = populate_size
+
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
+        self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)

-    def init_memory(self):
+    def _init_memory(self):
         logger.info("Populating replay memory...")

         # fill some for the history

@@ -60,8 +67,8 @@ class ExpReplay(DataFlow, Callback):
             self._populate_exp()
         self.exploration = old_exploration
-        with tqdm(total=self.populate_size) as pbar:
-            while len(self.mem) < self.populate_size:
+        with tqdm(total=self.init_memory_size) as pbar:
+            while len(self.mem) < self.init_memory_size:
                 self._populate_exp()
                 pbar.update()

@@ -96,7 +103,7 @@ class ExpReplay(DataFlow, Callback):
     def get_data(self):
         # new s is considered useless if isOver==True
         while True:
-            batch_exp = [self.sample_one() for _ in range(self.batch_size)]
+            batch_exp = [self._sample_one() for _ in range(self.batch_size)]

             #import cv2
             #def view_state(state, next_state):

@@ -116,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
             for _ in range(self.update_frequency):
                 self._populate_exp()

-    def sample_one(self):
+    def _sample_one(self):
         """ return the transition tuple for
             [idx, idx+history_len] -> [idx+1, idx+1+history_len]
             it's the transition from state idx+history_len-1 to state idx+history_len

@@ -155,14 +162,14 @@ class ExpReplay(DataFlow, Callback):
         return [state, action, reward, next_state, isOver]

     # Callback-related:
     def _before_train(self):
-        self.init_memory()
+        self._init_memory()

     def _trigger_epoch(self):
         if self.exploration > self.end_exploration:
             self.exploration -= self.exploration_epoch_anneal
             logger.info("Exploration changed to {}".format(self.exploration))
+        # log player statistics
         stats = self.player.get_stat()
         for k, v in six.iteritems(stats):
             if isinstance(v, float):

@@ -177,10 +184,10 @@ class ExpReplay(DataFlow, Callback):
     player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
     E = ExpReplay(predictor,
             player=player,
-            num_actions=player.get_num_actions(),
+            num_actions=player.get_action_space().num_actions(),
             populate_size=1001,
             history_len=4)
-    E.init_memory()
+    E._init_memory()
     for k in E.get_data():
         import IPython as IP;
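Note: the populate_size -> init_memory_size rename stays backward compatible: the old keyword remains in the signature with a None default and, when supplied, is copied onto the new one with a warning. The same shim, reduced to a standalone sketch (make_replay is a hypothetical stand-in for ExpReplay.__init__):

import warnings

def make_replay(batch_size=32, memory_size=int(1e6),
                populate_size=None,        # deprecated spelling, kept for old callers
                init_memory_size=50000):
    # back-compat shim, same shape as the one added to ExpReplay.__init__
    if populate_size is not None:
        warnings.warn("populate_size is deprecated in favor of init_memory_size")
        init_memory_size = populate_size
    return dict(batch_size=batch_size, memory_size=memory_size,
                init_memory_size=init_memory_size)

print(make_replay(populate_size=1001))     # old keyword still works
print(make_replay(init_memory_size=5000))  # new keyword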