Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b61d0722
Commit
b61d0722
authored
Jul 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
finish_episode in RLenv
parent
300b2c3a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
19 deletions
+17
-19
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+6
-3
tensorpack/RL/common.py
tensorpack/RL/common.py
+4
-5
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+7
-0
tensorpack/RL/history.py
tensorpack/RL/history.py
+0
-11
No files found.
tensorpack/RL/atari.py
View file @
b61d0722
...
@@ -32,7 +32,8 @@ _ALE_LOCK = threading.Lock()
...
@@ -32,7 +32,8 @@ _ALE_LOCK = threading.Lock()
class
AtariPlayer
(
RLEnvironment
):
class
AtariPlayer
(
RLEnvironment
):
"""
"""
A wrapper for atari emulator.
A wrapper for atari emulator.
NOTE: will automatically restart when a real episode ends
Will automatically restart when a real episode ends (isOver might just mean a
loss of lives, not game over).
"""
"""
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
),
frame_skip
=
4
,
image_shape
=
(
84
,
84
),
nullop_start
=
30
,
frame_skip
=
4
,
image_shape
=
(
84
,
84
),
nullop_start
=
30
,
...
@@ -129,9 +130,10 @@ class AtariPlayer(RLEnvironment):
...
@@ -129,9 +130,10 @@ class AtariPlayer(RLEnvironment):
def
get_action_space
(
self
):
def
get_action_space
(
self
):
return
DiscreteActionSpace
(
len
(
self
.
actions
))
return
DiscreteActionSpace
(
len
(
self
.
actions
))
def
restart_episode
(
self
):
def
finish_episode
(
self
):
if
self
.
current_episode_score
.
count
>
0
:
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
def
restart_episode
(
self
):
self
.
current_episode_score
.
reset
()
self
.
current_episode_score
.
reset
()
self
.
ale
.
reset_game
()
self
.
ale
.
reset_game
()
...
@@ -162,6 +164,7 @@ class AtariPlayer(RLEnvironment):
...
@@ -162,6 +164,7 @@ class AtariPlayer(RLEnvironment):
self
.
current_episode_score
.
feed
(
r
)
self
.
current_episode_score
.
feed
(
r
)
isOver
=
self
.
ale
.
game_over
()
isOver
=
self
.
ale
.
game_over
()
if
isOver
:
if
isOver
:
self
.
finish_episode
()
self
.
restart_episode
()
self
.
restart_episode
()
if
self
.
live_lost_as_eoe
:
if
self
.
live_lost_as_eoe
:
isOver
=
isOver
or
newlives
<
oldlives
isOver
=
isOver
or
newlives
<
oldlives
...
...
tensorpack/RL/common.py
View file @
b61d0722
...
@@ -40,7 +40,9 @@ class PreventStuckPlayer(ProxyPlayer):
...
@@ -40,7 +40,9 @@ class PreventStuckPlayer(ProxyPlayer):
self
.
act_que
.
clear
()
self
.
act_que
.
clear
()
class
LimitLengthPlayer
(
ProxyPlayer
):
class
LimitLengthPlayer
(
ProxyPlayer
):
""" Limit the total number of actions in an episode"""
""" Limit the total number of actions in an episode.
Does not auto restart.
"""
def
__init__
(
self
,
player
,
limit
):
def
__init__
(
self
,
player
,
limit
):
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
self
.
limit
=
limit
self
.
limit
=
limit
...
@@ -51,10 +53,6 @@ class LimitLengthPlayer(ProxyPlayer):
...
@@ -51,10 +53,6 @@ class LimitLengthPlayer(ProxyPlayer):
self
.
cnt
+=
1
self
.
cnt
+=
1
if
self
.
cnt
>=
self
.
limit
:
if
self
.
cnt
>=
self
.
limit
:
isOver
=
True
isOver
=
True
self
.
player
.
restart_episode
()
if
isOver
:
#print self.cnt, self.player.stats # to see what limit is appropriate
self
.
cnt
=
0
return
(
r
,
isOver
)
return
(
r
,
isOver
)
def
restart_episode
(
self
):
def
restart_episode
(
self
):
...
@@ -67,6 +65,7 @@ class AutoRestartPlayer(ProxyPlayer):
...
@@ -67,6 +65,7 @@ class AutoRestartPlayer(ProxyPlayer):
def
action
(
self
,
act
):
def
action
(
self
,
act
):
r
,
isOver
=
self
.
player
.
action
(
act
)
r
,
isOver
=
self
.
player
.
action
(
act
)
if
isOver
:
if
isOver
:
self
.
player
.
finish_episode
()
self
.
player
.
restart_episode
()
self
.
player
.
restart_episode
()
return
r
,
isOver
return
r
,
isOver
...
...
tensorpack/RL/envbase.py
View file @
b61d0722
...
@@ -36,6 +36,10 @@ class RLEnvironment(object):
...
@@ -36,6 +36,10 @@ class RLEnvironment(object):
""" Start a new episode, even if the current hasn't ended """
""" Start a new episode, even if the current hasn't ended """
raise
NotImplementedError
()
raise
NotImplementedError
()
def
finish_episode
(
self
):
""" get called when an episode finished"""
pass
def
get_action_space
(
self
):
def
get_action_space
(
self
):
""" return an `ActionSpace` instance"""
""" return an `ActionSpace` instance"""
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -112,5 +116,8 @@ class ProxyPlayer(RLEnvironment):
...
@@ -112,5 +116,8 @@ class ProxyPlayer(RLEnvironment):
def
restart_episode
(
self
):
def
restart_episode
(
self
):
self
.
player
.
restart_episode
()
self
.
player
.
restart_episode
()
def
finish_episode
(
self
):
self
.
player
.
finish_episode
()
def
get_action_space
(
self
):
def
get_action_space
(
self
):
return
self
.
player
.
get_action_space
()
return
self
.
player
.
get_action_space
()
tensorpack/RL/history.py
View file @
b61d0722
...
@@ -49,14 +49,3 @@ class HistoryFramePlayer(ProxyPlayer):
...
@@ -49,14 +49,3 @@ class HistoryFramePlayer(ProxyPlayer):
self
.
history
.
clear
()
self
.
history
.
clear
()
self
.
history
.
append
(
self
.
player
.
current_state
())
self
.
history
.
append
(
self
.
player
.
current_state
())
class
TimePointHistoryFramePlayer
(
HistoryFramePlayer
):
""" Include history from a list of time points in the past"""
def
__init__
(
self
,
player
,
hists
):
""" hists: a list of positive integers. 1 means the last frame"""
queue_size
=
max
(
hists
)
+
1
super
(
TimePointHistoryFramePlayer
,
self
)
.
__init__
(
player
,
queue_size
)
self
.
hists
=
hists
def
current_state
(
self
):
# TODO
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment