Commit 4c7348c3, authored May 25, 2016 by Yuxin Wu
    change how expreplay works...

Parent: 0c5e39eb
Showing 3 changed files with 183 additions and 147 deletions:

  examples/Atari2600/DQN.py              +40   -32
  tensorpack/dataflow/RL.py              +121  -38
  tensorpack/dataflow/dataset/atari.py   +22   -77

examples/Atari2600/DQN.py (view file @ 4c7348c3)

@@ -3,8 +3,8 @@
 # File: DQN.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import tensorflow as tf
 import numpy as np
+import tensorflow as tf
 import os, sys, re
 import random
 import argparse

@@ -22,7 +22,7 @@ from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictW
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
-from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
+from tensorpack.dataflow.dataset import AtariPlayer
 from tensorpack.dataflow.RL import ExpReplay

 """

@@ -36,13 +36,13 @@ IMAGE_SIZE = 84
 NUM_ACTIONS = None
 FRAME_HISTORY = 4
 ACTION_REPEAT = 3
-HEIGHT_RANGE = (36, 204)    # for breakout
-# HEIGHT_RANGE = (28, -8)   # for pong
+# HEIGHT_RANGE = (36, 204)  # for breakout
+HEIGHT_RANGE = (28, -8)     # for pong
 GAMMA = 0.99
 BATCH_SIZE = 32
 INIT_EXPLORATION = 1
-EXPLORATION_EPOCH_ANNEAL = 0.0025
+EXPLORATION_EPOCH_ANNEAL = 0.0020
 END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6

@@ -62,15 +62,20 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image, is_training):
         """ image: [0,255]"""
-        image = image / 128.0 - 1
-        with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
-            l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=2)
-            l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=2)
-            l = Conv2D('conv2', l, out_channel=64, kernel_shape=4, stride=2)
+        image = image / 255.0
+        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
+            l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
+            l = MaxPooling('pool0', l, 2)
+            l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
+            l = MaxPooling('pool1', l, 2)
+            l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
+            l = MaxPooling('pool2', l, 2)
             l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
+            l = MaxPooling('pool3', l, 2)
+            l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)

-        l = FullyConnected('fc0', l, 512)
-        l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity, summary_activation=False)
+        l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
+        l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
         return l

     def _build_graph(self, inputs, is_training):

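For orientation, a rough sketch of how the new conv/pool stack changes the spatial resolution. It assumes an 84x84 input (IMAGE_SIZE in this file) and that the convolutions are 'SAME'-padded so only the four 2x2 poolings shrink the image; both assumptions are about tensorpack defaults and are not stated in the diff.

# Hypothetical back-of-the-envelope check, not part of the commit.
import math

def after_poolings(side, n, ceil=True):
    # apply n poolings of shape 2; rounding up or down depends on the padding mode
    for _ in range(n):
        side = int(math.ceil(side / 2.0)) if ceil else side // 2
    return side

print(after_poolings(84, 4, ceil=True))    # 84 -> 42 -> 21 -> 11 -> 6
print(after_poolings(84, 4, ceil=False))   # 84 -> 42 -> 21 -> 10 -> 5
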
@@ -136,14 +141,14 @@ def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
     while True:
-        s = player.current_state()
+        s = player.current_state()  # XXX
         outputs = func([[s]])
         action_value = outputs[0][0]
         act = action_value.argmax()
         if verbose:
             print action_value, act
         if random.random() < 0.01:
-            act = random.choice(range(player.driver.get_num_actions()))
+            act = random.choice(range(NUM_ACTIONS))
         if len(que) == que.maxlen \
                 and que.count(que[0]) == que.maxlen:
             act = 1    # hack, avoid stuck

@@ -156,10 +161,11 @@ def play_one_episode(player, func, verbose=False):
     return tot_reward

 def play_model(model_path, romfile):
-    player = AtariPlayer(AtariDriver(romfile, viz=0.01, height_range=HEIGHT_RANGE), action_repeat=ACTION_REPEAT)
+    player = HistoryFramePlayer(AtariPlayer(
+        romfile, viz=0.01, height_range=HEIGHT_RANGE, frame_skip=ACTION_REPEAT), FRAME_HISTORY)
     global NUM_ACTIONS
-    NUM_ACTIONS = player.driver.get_num_actions()
+    NUM_ACTIONS = player.player.get_num_actions()
     M = Model()
     cfg = PredictConfig(

@@ -186,10 +192,11 @@ def eval_model_multiprocess(model_path, romfile):
             self.outq = outqueue

         def run(self):
-            player = AtariPlayer(AtariDriver(romfile, viz=0, height_range=HEIGHT_RANGE), action_repeat=ACTION_REPEAT)
+            player = HistoryFramePlayer(AtariPlayer(
+                romfile, viz=0, height_range=HEIGHT_RANGE, frame_skip=ACTION_REPEAT), FRAME_HISTORY)
             global NUM_ACTIONS
-            NUM_ACTIONS = player.driver.get_num_actions()
+            NUM_ACTIONS = player.player.get_num_actions()
             self._init_runtime()
             while True:
                 score = play_one_episode(player, self.func)

@@ -226,15 +233,15 @@ def get_config(romfile):
         os.path.join('train_log', basename[:basename.rfind('.')]))

     M = Model()
-    driver = AtariDriver(romfile, height_range=HEIGHT_RANGE)
+    player = AtariPlayer(romfile, height_range=HEIGHT_RANGE, frame_skip=ACTION_REPEAT)
     global NUM_ACTIONS
-    NUM_ACTIONS = driver.get_num_actions()
+    NUM_ACTIONS = player.get_num_actions()

     dataset_train = ExpReplay(
             predictor=current_predictor,
-            player=AtariPlayer(driver, hist_len=FRAME_HISTORY, action_repeat=ACTION_REPEAT),
+            player=player,
             num_actions=NUM_ACTIONS,
             memory_size=MEMORY_SIZE,
             batch_size=BATCH_SIZE,

@@ -242,22 +249,23 @@ def get_config(romfile):
             exploration=INIT_EXPLORATION,
             end_exploration=END_EXPLORATION,
             exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
-            reward_clip=(-1, 2))
+            reward_clip=(-1, 1),
+            history_len=FRAME_HISTORY)

-    lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
+    lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

     class Evaluator(Callback):
         def _trigger_epoch(self):
             logger.info("Evaluating...")
             output = subprocess.check_output(
-"""{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
+"""CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
                 sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
             output = output.strip()
             output = output[output.find(']')+1:]
-            mean, maximum = re.findall('[0-9\.]+', output)
-            self.trainer.write_scalar_summary('eval_mean_score', mean)
-            self.trainer.write_scalar_summary('eval_max_score', maximum)
+            mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
+            self.trainer.write_scalar_summary('mean_score', mean)
+            self.trainer.write_scalar_summary('max_score', maximum)

     return TrainConfig(
         dataset=dataset_train,

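A hedged illustration of the regex change in the Evaluator. The sample line below is made up (the real output of the eval task is not shown in this diff), but it shows why '\-' and the trailing [-2:] were added: negative scores now match, and only the last two numbers are taken even if earlier text contains digits.

import re

# hypothetical grep'd "Average" line; the actual format is an assumption
output = "[0525 12:00:00] Average Score: mean=-1.5 max=3.0"
output = output[output.find(']')+1:]
mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
print("mean=%s max=%s" % (mean, maximum))   # mean=-1.5 max=3.0
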
@@ -269,7 +277,7 @@ def get_config(romfile):
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
             dataset_train,
-            PeriodicCallback(Evaluator(), 1),
+            PeriodicCallback(Evaluator(), 2),
         ]),
         session_config=get_default_sess_config(0.5),
         model=M,

tensorpack/dataflow/RL.py (view file @ 4c7348c3)

@@ -19,10 +19,10 @@ from tensorpack.callbacks.base import Callback
 Implement RL-related data preprocessing
 """

-__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment']
+__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment', 'HistoryFramePlayer']

 Experience = namedtuple('Experience',
-        ['state', 'action', 'reward', 'next', 'isOver'])
+        ['state', 'action', 'reward', 'isOver'])

 class RLEnvironment(object):
     __meta__ = ABCMeta

@@ -65,6 +65,49 @@ class NaiveRLEnvironment(RLEnvironment):
         self.k = act
         return (self.k, self.k > 10)

+class ProxyPlayer(RLEnvironment):
+    def __init__(self, player):
+        self.player = player
+
+    def get_stat(self):
+        return self.player.get_stat()
+
+    def reset_stat(self):
+        self.player.reset_stat()
+
+    def current_state(self):
+        return self.player.current_state()
+
+    def action(self, act):
+        return self.player.action(act)
+
+class HistoryFramePlayer(ProxyPlayer):
+    def __init__(self, player, hist_len):
+        super(HistoryFramePlayer, self).__init__(player)
+        self.history = deque(maxlen=hist_len)
+
+        s = self.player.current_state()
+        self.history.append(s)
+
+    def current_state(self):
+        assert len(self.history) != 0
+        diff_len = self.history.maxlen - len(self.history)
+        if diff_len == 0:
+            return np.concatenate(self.history, axis=2)
+        zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
+        for k in self.history:
+            zeros.append(k)
+        return np.concatenate(zeros, axis=2)
+
+    def action(self, act):
+        r, isOver = self.player.action(act)
+        s = self.player.current_state()
+        self.history.append(s)
+
+        if isOver:  # s would be a new episode
+            self.history.clear()
+            self.history.append(s)
+        return (r, isOver)
+
 class ExpReplay(DataFlow, Callback):
     """

@@ -82,11 +125,15 @@ class ExpReplay(DataFlow, Callback):
                  end_exploration=0.1,
                  exploration_epoch_anneal=0.002,
                  reward_clip=None,
-                 new_experience_per_step=1
+                 new_experience_per_step=1,
+                 history_len=1
                  ):
         """
-        :param predictor: callabale. called with a state, return a distribution
+        :param predictor: a callabale calling the up-to-date network.
+            called with a state, return a distribution
         :param player: a `RLEnvironment`
         :param num_actions: int
+        :param history_len: length of history frames to concat. zero-filled initial frames
         """
         for k, v in locals().items():
             if k != 'self':

@@ -106,51 +153,83 @@ class ExpReplay(DataFlow, Callback):
             raise RuntimeError("Don't run me in multiple processes")

     def _populate_exp(self):
-        p = self.rng.rand()
         old_s = self.player.current_state()
-        if p <= self.exploration:
+        if self.rng.rand() <= self.exploration:
             act = self.rng.choice(range(self.num_actions))
         else:
-            act = np.argmax(self.predictor(old_s))  # TODO race condition in session?
+            # build a history state
+            ss = [old_s]
+            for k in range(1, self.history_len):
+                hist_exp = self.mem[-k]
+                if hist_exp.isOver:
+                    ss.append(np.zeros_like(ss[0]))
+                else:
+                    ss.append(hist_exp.state)
+            ss = np.concatenate(ss, axis=2)
+            act = np.argmax(self.predictor(ss))
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
-        s = self.player.current_state()

-        #def view_state(state):
-            #""" for debug state representation"""
-            #r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
-            #print r.shape
-            #cv2.imshow("state", r)
-            #cv2.waitKey()
-        #print act, reward
-        #view_state(s)

-        # s is considered useless if isOver==True
-        self.mem.append(Experience(old_s, act, reward, s, isOver))
+        self.mem.append(Experience(old_s, act, reward, isOver))

     def get_data(self):
-        # new s is considered useless if isOver==True
         while True:
-            idxs = self.rng.randint(len(self.mem), size=self.batch_size)
-            batch_exp = [self.mem[k] for k in idxs]
+            batch_exp = [self.sample_one() for _ in range(self.batch_size)]

+            def view_state(state, next_state):
+                """ for debug state representation"""
+                r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
+                r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
+                print r.shape
+                r = np.concatenate([r, r2], axis=0)
+                cv2.imshow("state", r)
+                cv2.waitKey()
+            exp = batch_exp[0]
+            print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
+            view_state(exp[0], exp[1])

             yield self._process_batch(batch_exp)
             for _ in range(self.new_experience_per_step):
                 self._populate_exp()

+    def sample_one(self):
+        """ return the transition tuple for
+            [idx, idx+history_len] -> [idx+1, idx+1+history_len]
+            it's the transition from state idx+history_len-1 to state idx+history_len
+        """
+        # look for a state to start with
+        # when x.isOver==True, (x+1).state is of a different episode
+        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
+        start_idx = idx + self.history_len - 1
+
+        def concat(idx):
+            v = [self.mem[x].state for x in range(idx, idx + self.history_len)]
+            return np.concatenate(v, axis=2)
+        state = concat(idx)
+        next_state = concat(idx + 1)
+        reward = self.mem[start_idx].reward
+        action = self.mem[start_idx].action
+        isOver = self.mem[start_idx].isOver
+
+        # zero-fill state before starting
+        zero_fill = False
+        for k in range(1, self.history_len):
+            if self.mem[start_idx - k].isOver:
+                zero_fill = True
+            if zero_fill:
+                state[:,:,-k-1] = 0
+                if k + 2 <= self.history_len:
+                    next_state[:,:,-k-2] = 0
+        return (state, next_state, reward, action, isOver)
+
     def _process_batch(self, batch_exp):
-        state_shape = batch_exp[0].state.shape
-        state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
-        next_state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
-        reward = np.zeros((self.batch_size,), dtype='float32')
-        action = np.zeros((self.batch_size,), dtype='int32')
-        isOver = np.zeros((self.batch_size,), dtype='bool')
-        for idx, b in enumerate(batch_exp):
-            state[idx] = b.state
-            action[idx] = b.action
-            next_state[idx] = b.next
-            reward[idx] = b.reward
-            isOver[idx] = b.isOver
+        state = np.array([e[0] for e in batch_exp])
+        next_state = np.array([e[1] for e in batch_exp])
+        reward = np.array([e[2] for e in batch_exp])
+        action = np.array([e[3] for e in batch_exp])
+        isOver = np.array([e[4] for e in batch_exp])
         return [state, action, reward, next_state, isOver]

     # Callback-related:

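A small worked example of the index arithmetic in sample_one, with a toy replay memory of scalar "frames" and history_len=4. The values are hypothetical; the concat and zero-fill logic follow the method above and show how frames from a finished episode are blanked out.

import numpy as np
from collections import namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])

history_len = 4
# toy memory: frame i has pixel value i; the episode ends at index 4
mem = [Experience(np.full((1, 1, 1), float(i)), 0, 0.0, i == 4) for i in range(8)]

idx = 2                              # pretend rng.randint(...) picked 2
start_idx = idx + history_len - 1    # = 5, the transition being sampled

def concat(i):
    return np.concatenate([mem[x].state for x in range(i, i + history_len)], axis=2)

state = concat(idx)                  # frames 2, 3, 4, 5
next_state = concat(idx + 1)         # frames 3, 4, 5, 6

zero_fill = False
for k in range(1, history_len):
    if mem[start_idx - k].isOver:    # mem[4] ended an episode
        zero_fill = True
    if zero_fill:
        state[:, :, -k - 1] = 0
        if k + 2 <= history_len:
            next_state[:, :, -k - 2] = 0

print(state.flatten())               # [0. 0. 0. 5.]  only frame 5 belongs to the new episode
print(next_state.flatten())          # [0. 0. 5. 6.]
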
@@ -170,12 +249,16 @@ class ExpReplay(DataFlow, Callback):

 if __name__ == '__main__':
-    from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
+    from tensorpack.dataflow.dataset import AtariPlayer
+    import sys
     predictor = lambda x: np.array([1,1,1,1])
     predictor.initialized = False
-    E = AtariExpReplay(predictor, predictor,
-            AtariPlayer(AtariDriver('../../space_invaders.bin', viz=0.01)),
-            populate_size=1000)
+    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=20)
+    E = ExpReplay(predictor,
+            player=player,
+            num_actions=player.get_num_actions(),
+            populate_size=1001,
+            history_len=4)
     E.init_memory()

     for k in E.get_data():

tensorpack/dataflow/dataset/atari.py (view file @ 4c7348c3)

@@ -9,6 +9,7 @@ import os
 import cv2
 from collections import deque

 from ...utils import get_rng, logger
+from ...utils.stat import StatCounter
 from ..RL import RLEnvironment

 try:

@@ -16,23 +17,27 @@ try:
 except ImportError:
     logger.warn("Cannot import ale_python_interface, Atari won't be available.")

-__all__ = ['AtariDriver', 'AtariPlayer']
+__all__ = ['AtariPlayer']

-class AtariDriver(RLEnvironment):
+class AtariPlayer(RLEnvironment):
     """
     A wrapper for atari emulator.
     """
-    def __init__(self, rom_file, viz=0, height_range=(None, None)):
+    def __init__(self, rom_file, viz=0, height_range=(None, None),
+                 frame_skip=4, image_shape=(84, 84)):
         """
         :param rom_file: path to the rom
+        :param frame_skip: skip every k frames
+        :param image_shape: (w, h)
         :param height_range: (h1, h2) to cut
         :param viz: the delay. visualize the game while running. 0 to disable
         """
+        super(AtariPlayer, self).__init__()
         self.ale = ALEInterface()
         self.rng = get_rng(self)

         self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
-        self.ale.setInt("frame_skip", 1)
+        self.ale.setInt("frame_skip", frame_skip)
+        self.ale.setBool('color_averaging', True)
         self.ale.loadROM(rom_file)
         self.width, self.height = self.ale.getScreenDims()

@@ -45,9 +50,11 @@ class AtariDriver(RLEnvironment):
         if self.viz and isinstance(self.viz, float):
             cv2.startWindowThread()
             cv2.namedWindow(self.romname)

-        self.framenum = 0
         self.height_range = height_range
+        self.framenum = 0
+        self.image_shape = image_shape
+        self.current_episode_score = StatCounter()

         self._reset()

@@ -61,9 +68,9 @@ class AtariDriver(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale image, max-pooled over the last frame.
+        :returns: a gray-scale (h, w, 1) image
         """
-        now = self._grab_raw_image()
+        ret = self._grab_raw_image()
         if self.viz:
             if isinstance(self.viz, float):
                 cv2.imshow(self.romname, ret)

@@ -73,6 +80,8 @@ class AtariDriver(RLEnvironment):
             self.framenum += 1
         ret = ret[self.height_range[0]:self.height_range[1],:]
         ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
+        ret = cv2.resize(ret, self.image_shape)
+        ret = np.expand_dims(ret, axis=2)
         return ret

     def get_num_actions(self):

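A hedged sketch of the per-frame preprocessing this hunk completes inside current_state(): crop by height_range, keep the Y (luminance) channel, resize to image_shape, and add a trailing channel axis. It runs on a synthetic frame so no ROM or ALE install is needed; the 210x160 screen size is the usual Atari resolution and is an assumption here.

import numpy as np
import cv2

height_range = (28, -8)      # the pong setting from DQN.py
image_shape = (84, 84)

# synthetic BGR frame standing in for the ALE screen
frame = (np.random.rand(210, 160, 3) * 255).astype('uint8')

ret = frame[height_range[0]:height_range[1], :]        # crop rows 28 .. -8
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:, :, 0]    # luminance only
ret = cv2.resize(ret, image_shape)                     # cv2 takes (w, h)
ret = np.expand_dims(ret, axis=2)
print(ret.shape)                                       # (84, 84, 1)
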
@@ -82,6 +91,7 @@ class AtariDriver(RLEnvironment):
         return len(self.actions)

     def _reset(self):
+        self.current_episode_score.reset()
         self.ale.reset_game()

     def action(self, act):

@@ -90,80 +100,13 @@ class AtariDriver(RLEnvironment):
         :returns: (reward, isOver)
         """
         r = self.ale.act(self.actions[act])
+        self.current_episode_score.feed(r)
         isOver = self.ale.game_over()
         if isOver:
+            self.stats['score'].append(self.current_episode_score.sum)
             self._reset()
         return (r, isOver)

-class AtariPlayer(RLEnvironment):
-    """ An Atari game player with limited memory and FPS"""
-    def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
-        """
-        :param driver: an `AtariDriver` instance.
-        :param hist_len: history(memory) length
-        :param action_repeat: repeat each action `action_repeat` times and skip those frames
-        :param image_shape: the shape of the observed image
-        """
-        super(AtariPlayer, self).__init__()
-        for k, v in locals().items():
-            if k != 'self':
-                setattr(self, k, v)
-        self.last_act = 0
-        self.frames = deque(maxlen=hist_len)
-        self.current_accum_score = 0
-        self.restart()
-
-    def restart(self):
-        """
-        Restart the game and populate frames with the beginning frame
-        """
-        self.current_accum_score = 0
-        self.frames.clear()
-        s = self.driver.current_state()
-        s = cv2.resize(s, self.image_shape)
-        for _ in range(self.hist_len):
-            self.frames.append(s)
-
-    def current_state(self):
-        """
-        Return a current state of shape `image_shape + (hist_len,)`
-        """
-        return self._build_state()
-
-    def action(self, act):
-        """
-        Perform an action
-        :param act: index of the action
-        :returns: (reward, isOver)
-        """
-        self.last_act = act
-        return self._observe()
-
-    def _build_state(self):
-        assert len(self.frames) == self.hist_len
-        m = np.array(self.frames)
-        m = m.transpose([1, 2, 0])
-        return m
-
-    def _observe(self):
-        """ if isOver==True, current_state will return the new episode
-        """
-        totr = 0
-        for k in range(self.action_repeat):
-            r, isOver = self.driver.action(self.last_act)
-            s = self.driver.current_state()
-            totr += r
-            if isOver:
-                break
-        s = cv2.resize(s, self.image_shape)
-        self.current_accum_score += totr
-        self.frames.append(s)
-        if isOver:
-            self.stats['score'].append(self.current_accum_score)
-            self.restart()
-        return (totr, isOver)
-
     def get_stat(self):
         try:
             return {'avg_score': np.mean(self.stats['score']),

@@ -173,7 +116,8 @@ class AtariPlayer(RLEnvironment):
 if __name__ == '__main__':
     import sys
-    a = AtariDriver(sys.argv[1], viz=0.01, height_range=(28,-8))
+    a = AtariPlayer(sys.argv[1],
+                    viz=0.01, height_range=(28,-8))
     num = a.get_num_actions()
     rng = get_rng(num)
     import time

@@ -183,6 +127,7 @@ if __name__ == '__main__':
         act = rng.choice(range(num))
         print act
         r, o = a.action(act)
+        a.current_state()
         #time.sleep(0.1)
         print(r, o)