Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
64707bfa
Commit
64707bfa
authored
May 29, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better atari env
parent
80722088
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
13 deletions
+34
-13
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+34
-13
No files found.
tensorpack/RL/atari.py
View file @
64707bfa
...
@@ -8,6 +8,7 @@ import time
...
@@ -8,6 +8,7 @@ import time
import
os
import
os
import
cv2
import
cv2
from
collections
import
deque
from
collections
import
deque
from
six.moves
import
range
from
..utils
import
get_rng
,
logger
from
..utils
import
get_rng
,
logger
from
..utils.stat
import
StatCounter
from
..utils.stat
import
StatCounter
...
@@ -25,7 +26,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -25,7 +26,7 @@ class AtariPlayer(RLEnvironment):
A wrapper for atari emulator.
A wrapper for atari emulator.
"""
"""
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
frame_skip
=
4
,
image_shape
=
(
84
,
84
)):
frame_skip
=
4
,
image_shape
=
(
84
,
84
)
,
nullop_start
=
30
):
"""
"""
:param rom_file: path to the rom
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param frame_skip: skip every k frames
...
@@ -37,9 +38,12 @@ class AtariPlayer(RLEnvironment):
...
@@ -37,9 +38,12 @@ class AtariPlayer(RLEnvironment):
self
.
ale
=
ALEInterface
()
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
0
,
1000
))
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
0
,
10000
))
self
.
ale
.
setInt
(
"frame_skip"
,
frame_skip
)
self
.
ale
.
setInt
(
"frame_skip"
,
1
)
self
.
ale
.
setBool
(
'color_averaging'
,
True
)
self
.
ale
.
setBool
(
'color_averaging'
,
False
)
# manual.pdf suggests otherwise. may need to check
self
.
ale
.
setFloat
(
'repeat_action_probability'
,
0.0
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
...
@@ -51,8 +55,9 @@ class AtariPlayer(RLEnvironment):
...
@@ -51,8 +55,9 @@ class AtariPlayer(RLEnvironment):
if
self
.
viz
and
isinstance
(
self
.
viz
,
float
):
if
self
.
viz
and
isinstance
(
self
.
viz
,
float
):
cv2
.
startWindowThread
()
cv2
.
startWindowThread
()
cv2
.
namedWindow
(
self
.
romname
)
cv2
.
namedWindow
(
self
.
romname
)
self
.
framenum
=
0
self
.
frame_skip
=
frame_skip
self
.
nullop_start
=
nullop_start
self
.
height_range
=
height_range
self
.
height_range
=
height_range
self
.
image_shape
=
image_shape
self
.
image_shape
=
image_shape
self
.
current_episode_score
=
StatCounter
()
self
.
current_episode_score
=
StatCounter
()
...
@@ -63,8 +68,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -63,8 +68,7 @@ class AtariPlayer(RLEnvironment):
"""
"""
:returns: the current 3-channel image
:returns: the current 3-channel image
"""
"""
m
=
np
.
zeros
(
self
.
height
*
self
.
width
*
3
,
dtype
=
np
.
uint8
)
m
=
self
.
ale
.
getScreenRGB
()
self
.
ale
.
getScreenRGB
(
m
)
return
m
.
reshape
((
self
.
height
,
self
.
width
,
3
))
return
m
.
reshape
((
self
.
height
,
self
.
width
,
3
))
def
current_state
(
self
):
def
current_state
(
self
):
...
@@ -72,15 +76,15 @@ class AtariPlayer(RLEnvironment):
...
@@ -72,15 +76,15 @@ class AtariPlayer(RLEnvironment):
:returns: a gray-scale (h, w, 1) image
:returns: a gray-scale (h, w, 1) image
"""
"""
ret
=
self
.
_grab_raw_image
()
ret
=
self
.
_grab_raw_image
()
# max-pooled over the last screen
ret
=
np
.
maximum
(
ret
,
self
.
last_raw_screen
)
if
self
.
viz
:
if
self
.
viz
:
if
isinstance
(
self
.
viz
,
float
):
if
isinstance
(
self
.
viz
,
float
):
cv2
.
imshow
(
self
.
romname
,
ret
)
cv2
.
imshow
(
self
.
romname
,
ret
)
time
.
sleep
(
self
.
viz
)
time
.
sleep
(
self
.
viz
)
else
:
cv2
.
imwrite
(
"{}/{:06d}.jpg"
.
format
(
self
.
viz
,
self
.
framenum
),
ret
)
self
.
framenum
+=
1
ret
=
ret
[
self
.
height_range
[
0
]:
self
.
height_range
[
1
],:]
ret
=
ret
[
self
.
height_range
[
0
]:
self
.
height_range
[
1
],:]
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_BGR2YUV
)[:,:,
0
]
# 0.299,0.587.0.114. same as rgb2y in torch/image
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_RGB2GRAY
)
ret
=
cv2
.
resize
(
ret
,
self
.
image_shape
)
ret
=
cv2
.
resize
(
ret
,
self
.
image_shape
)
ret
=
np
.
expand_dims
(
ret
,
axis
=
2
)
ret
=
np
.
expand_dims
(
ret
,
axis
=
2
)
return
ret
return
ret
...
@@ -95,17 +99,34 @@ class AtariPlayer(RLEnvironment):
...
@@ -95,17 +99,34 @@ class AtariPlayer(RLEnvironment):
self
.
current_episode_score
.
reset
()
self
.
current_episode_score
.
reset
()
self
.
ale
.
reset_game
()
self
.
ale
.
reset_game
()
# random null-ops start
n
=
self
.
rng
.
randint
(
self
.
nullop_start
)
for
k
in
range
(
n
):
if
k
==
n
-
1
:
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
self
.
ale
.
act
(
0
)
def
action
(
self
,
act
):
def
action
(
self
,
act
):
"""
"""
:param act: an index of the action
:param act: an index of the action
:returns: (reward, isOver)
:returns: (reward, isOver)
"""
"""
r
=
self
.
ale
.
act
(
self
.
actions
[
act
])
oldlives
=
self
.
ale
.
lives
()
r
=
0
for
k
in
range
(
self
.
frame_skip
):
if
k
==
self
.
frame_skip
-
1
:
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
r
+=
self
.
ale
.
act
(
self
.
actions
[
act
])
newlives
=
self
.
ale
.
lives
()
if
self
.
ale
.
game_over
()
or
newlives
<
oldlives
:
break
self
.
current_episode_score
.
feed
(
r
)
self
.
current_episode_score
.
feed
(
r
)
isOver
=
self
.
ale
.
game_over
()
isOver
=
self
.
ale
.
game_over
()
if
isOver
:
if
isOver
:
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
_reset
()
self
.
_reset
()
isOver
=
isOver
or
newlives
<
oldlives
return
(
r
,
isOver
)
return
(
r
,
isOver
)
def
get_stat
(
self
):
def
get_stat
(
self
):
...
@@ -118,7 +139,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -118,7 +139,7 @@ class AtariPlayer(RLEnvironment):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
sys
import
sys
a
=
AtariPlayer
(
sys
.
argv
[
1
],
a
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
0.0
1
,
height_range
=
(
28
,
-
8
))
viz
=
0.0
3
,
height_range
=
(
28
,
-
8
))
num
=
a
.
get_num_actions
()
num
=
a
.
get_num_actions
()
rng
=
get_rng
(
num
)
rng
=
get_rng
(
num
)
import
time
import
time
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment