Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9319b978
Commit
9319b978
authored
May 21, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
height range for atari
parent
6e1f395d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+4
-3
tensorpack/dataflow/dataset/atari.py
tensorpack/dataflow/dataset/atari.py
+11
-8
No files found.
examples/Atari2600/DQN.py
View file @
9319b978
...
@@ -36,6 +36,7 @@ IMAGE_SIZE = 84
...
@@ -36,6 +36,7 @@ IMAGE_SIZE = 84
NUM_ACTIONS
=
None
NUM_ACTIONS
=
None
FRAME_HISTORY
=
4
FRAME_HISTORY
=
4
ACTION_REPEAT
=
3
ACTION_REPEAT
=
3
HEIGHT_RANGE
=
(
36
,
204
)
# for breakout
GAMMA
=
0.99
GAMMA
=
0.99
BATCH_SIZE
=
32
BATCH_SIZE
=
32
...
@@ -154,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
...
@@ -154,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
return
tot_reward
return
tot_reward
def
play_model
(
model_path
,
romfile
):
def
play_model
(
model_path
,
romfile
):
player
=
AtariPlayer
(
AtariDriver
(
romfile
,
viz
=
0.01
),
player
=
AtariPlayer
(
AtariDriver
(
romfile
,
viz
=
0.01
,
height_range
=
HEIGHT_RANGE
),
action_repeat
=
ACTION_REPEAT
)
action_repeat
=
ACTION_REPEAT
)
global
NUM_ACTIONS
global
NUM_ACTIONS
NUM_ACTIONS
=
player
.
driver
.
get_num_actions
()
NUM_ACTIONS
=
player
.
driver
.
get_num_actions
()
...
@@ -184,7 +185,7 @@ def eval_model_multiprocess(model_path, romfile):
...
@@ -184,7 +185,7 @@ def eval_model_multiprocess(model_path, romfile):
self
.
outq
=
outqueue
self
.
outq
=
outqueue
def
run
(
self
):
def
run
(
self
):
player
=
AtariPlayer
(
AtariDriver
(
romfile
,
viz
=
0
),
player
=
AtariPlayer
(
AtariDriver
(
romfile
,
viz
=
0
,
height_range
=
HEIGHT_RANGE
),
action_repeat
=
ACTION_REPEAT
)
action_repeat
=
ACTION_REPEAT
)
global
NUM_ACTIONS
global
NUM_ACTIONS
NUM_ACTIONS
=
player
.
driver
.
get_num_actions
()
NUM_ACTIONS
=
player
.
driver
.
get_num_actions
()
...
@@ -224,7 +225,7 @@ def get_config(romfile):
...
@@ -224,7 +225,7 @@ def get_config(romfile):
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
M
=
Model
()
M
=
Model
()
driver
=
AtariDriver
(
romfile
)
driver
=
AtariDriver
(
romfile
,
height_range
=
HEIGHT_RANGE
)
global
NUM_ACTIONS
global
NUM_ACTIONS
NUM_ACTIONS
=
driver
.
get_num_actions
()
NUM_ACTIONS
=
driver
.
get_num_actions
()
...
...
tensorpack/dataflow/dataset/atari.py
View file @
9319b978
...
@@ -22,7 +22,8 @@ class AtariDriver(object):
...
@@ -22,7 +22,8 @@ class AtariDriver(object):
"""
"""
A wrapper for atari emulator.
A wrapper for atari emulator.
"""
"""
def
__init__
(
self
,
rom_file
,
frame_skip
=
1
,
viz
=
0
):
def
__init__
(
self
,
rom_file
,
frame_skip
=
1
,
viz
=
0
,
height_range
=
(
None
,
None
)):
"""
"""
:param rom_file: path to the rom
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param frame_skip: skip every k frames
...
@@ -48,6 +49,7 @@ class AtariDriver(object):
...
@@ -48,6 +49,7 @@ class AtariDriver(object):
self
.
_reset
()
self
.
_reset
()
self
.
last_image
=
self
.
_grab_raw_image
()
self
.
last_image
=
self
.
_grab_raw_image
()
self
.
framenum
=
0
self
.
framenum
=
0
self
.
height_range
=
height_range
def
_grab_raw_image
(
self
):
def
_grab_raw_image
(
self
):
"""
"""
...
@@ -64,14 +66,15 @@ class AtariDriver(object):
...
@@ -64,14 +66,15 @@ class AtariDriver(object):
now
=
self
.
_grab_raw_image
()
now
=
self
.
_grab_raw_image
()
ret
=
np
.
maximum
(
now
,
self
.
last_image
)
ret
=
np
.
maximum
(
now
,
self
.
last_image
)
self
.
last_image
=
now
self
.
last_image
=
now
if
self
.
viz
and
isinstance
(
self
.
viz
,
float
):
ret
=
ret
[
self
.
height_range
[
0
]:
self
.
height_range
[
1
],:]
# several online repos all use this
cv2
.
imshow
(
self
.
romname
,
ret
)
if
self
.
viz
:
time
.
sleep
(
self
.
viz
)
if
isinstance
(
self
.
viz
,
float
):
elif
self
.
viz
:
cv2
.
imshow
(
self
.
romname
,
ret
)
cv2
.
imwrite
(
"{}/{:06d}.jpg"
.
format
(
self
.
viz
,
self
.
framenum
),
ret
)
time
.
sleep
(
self
.
viz
)
self
.
framenum
+=
1
else
:
cv2
.
imwrite
(
"{}/{:06d}.jpg"
.
format
(
self
.
viz
,
self
.
framenum
),
ret
)
self
.
framenum
+=
1
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_BGR2YUV
)[:,:,
0
]
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_BGR2YUV
)[:,:,
0
]
ret
=
ret
[
36
:
204
,:]
# several online repos all use this
return
ret
return
ret
def
get_num_actions
(
self
):
def
get_num_actions
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment