Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5fd47e6d
Commit
5fd47e6d
authored
May 29, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
don't use eoe in eval
parent
64707bfa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
6 deletions
+18
-6
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+5
-3
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+12
-3
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+1
-0
No files found.
examples/Atari2600/DQN.py
View file @
5fd47e6d
...
@@ -50,8 +50,10 @@ EVAL_EPISODE = 100
...
@@ -50,8 +50,10 @@ EVAL_EPISODE = 100
NUM_ACTIONS
=
None
NUM_ACTIONS
=
None
ROM_FILE
=
None
ROM_FILE
=
None
def
get_player
(
viz
=
False
):
def
get_player
(
viz
=
False
,
train
=
False
):
pl
=
AtariPlayer
(
ROM_FILE
,
viz
=
viz
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
)
player
=
AtariPlayer
(
ROM_FILE
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
,
image_shape
=
IMAGE_SIZE
[::
-
1
],
viz
=
viz
,
live_lost_as_eoe
=
train
)
global
NUM_ACTIONS
global
NUM_ACTIONS
NUM_ACTIONS
=
pl
.
get_num_actions
()
NUM_ACTIONS
=
pl
.
get_num_actions
()
return
pl
return
pl
...
@@ -220,7 +222,7 @@ def get_config():
...
@@ -220,7 +222,7 @@ def get_config():
M
=
Model
()
M
=
Model
()
dataset_train
=
ExpReplay
(
dataset_train
=
ExpReplay
(
predictor
=
current_predictor
,
predictor
=
current_predictor
,
player
=
get_player
(),
player
=
get_player
(
train
=
True
),
num_actions
=
NUM_ACTIONS
,
num_actions
=
NUM_ACTIONS
,
memory_size
=
MEMORY_SIZE
,
memory_size
=
MEMORY_SIZE
,
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
...
...
tensorpack/RL/atari.py
View file @
5fd47e6d
...
@@ -26,13 +26,16 @@ class AtariPlayer(RLEnvironment):
...
@@ -26,13 +26,16 @@ class AtariPlayer(RLEnvironment):
A wrapper for atari emulator.
A wrapper for atari emulator.
"""
"""
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
frame_skip
=
4
,
image_shape
=
(
84
,
84
),
nullop_start
=
30
):
frame_skip
=
4
,
image_shape
=
(
84
,
84
),
nullop_start
=
30
,
live_lost_as_eoe
=
True
):
"""
"""
:param rom_file: path to the rom
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param frame_skip: skip every k frames
:param image_shape: (w, h)
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable
:param viz: the delay. visualize the game while running. 0 to disable
:param nullop_start: start with random number of null ops
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
"""
"""
super
(
AtariPlayer
,
self
)
.
__init__
()
super
(
AtariPlayer
,
self
)
.
__init__
()
self
.
ale
=
ALEInterface
()
self
.
ale
=
ALEInterface
()
...
@@ -45,6 +48,8 @@ class AtariPlayer(RLEnvironment):
...
@@ -45,6 +48,8 @@ class AtariPlayer(RLEnvironment):
self
.
ale
.
setFloat
(
'repeat_action_probability'
,
0.0
)
self
.
ale
.
setFloat
(
'repeat_action_probability'
,
0.0
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
...
@@ -56,6 +61,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -56,6 +61,7 @@ class AtariPlayer(RLEnvironment):
cv2
.
startWindowThread
()
cv2
.
startWindowThread
()
cv2
.
namedWindow
(
self
.
romname
)
cv2
.
namedWindow
(
self
.
romname
)
self
.
live_lost_as_eoe
=
live_lost_as_eoe
self
.
frame_skip
=
frame_skip
self
.
frame_skip
=
frame_skip
self
.
nullop_start
=
nullop_start
self
.
nullop_start
=
nullop_start
self
.
height_range
=
height_range
self
.
height_range
=
height_range
...
@@ -101,6 +107,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -101,6 +107,7 @@ class AtariPlayer(RLEnvironment):
# random null-ops start
# random null-ops start
n
=
self
.
rng
.
randint
(
self
.
nullop_start
)
n
=
self
.
rng
.
randint
(
self
.
nullop_start
)
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
for
k
in
range
(
n
):
for
k
in
range
(
n
):
if
k
==
n
-
1
:
if
k
==
n
-
1
:
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
...
@@ -118,7 +125,8 @@ class AtariPlayer(RLEnvironment):
...
@@ -118,7 +125,8 @@ class AtariPlayer(RLEnvironment):
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
self
.
last_raw_screen
=
self
.
_grab_raw_image
()
r
+=
self
.
ale
.
act
(
self
.
actions
[
act
])
r
+=
self
.
ale
.
act
(
self
.
actions
[
act
])
newlives
=
self
.
ale
.
lives
()
newlives
=
self
.
ale
.
lives
()
if
self
.
ale
.
game_over
()
or
newlives
<
oldlives
:
if
self
.
ale
.
game_over
()
or
\
(
self
.
live_lost_as_eoe
and
newlives
<
oldlives
):
break
break
self
.
current_episode_score
.
feed
(
r
)
self
.
current_episode_score
.
feed
(
r
)
...
@@ -126,7 +134,8 @@ class AtariPlayer(RLEnvironment):
...
@@ -126,7 +134,8 @@ class AtariPlayer(RLEnvironment):
if
isOver
:
if
isOver
:
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
_reset
()
self
.
_reset
()
isOver
=
isOver
or
newlives
<
oldlives
if
self
.
live_lost_as_eoe
:
isOver
=
isOver
or
newlives
<
oldlives
return
(
r
,
isOver
)
return
(
r
,
isOver
)
def
get_stat
(
self
):
def
get_stat
(
self
):
...
...
tensorpack/predict/concurrency.py
View file @
5fd47e6d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
multiprocessing
,
threading
import
multiprocessing
,
threading
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
from
six.moves
import
queue
,
range
from
six.moves
import
queue
,
range
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment