Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dee9d398
Commit
dee9d398
authored
Jun 04, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
restart_episode in RLENV
parent
4e472eb5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
39 additions
and
17 deletions
+39
-17
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+7
-8
tensorpack/RL/common.py
tensorpack/RL/common.py
+17
-2
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+10
-5
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+1
-1
tensorpack/utils/stat.py
tensorpack/utils/stat.py
+3
-0
No files found.
examples/Atari2600/DQN.py
View file @
dee9d398
...
...
@@ -155,7 +155,7 @@ def play_one_episode(player, func, verbose=False):
if
verbose
:
print
(
act
)
return
act
return
player
.
play_one_episode
(
f
)
return
np
.
mean
(
player
.
play_one_episode
(
f
)
)
def
play_model
(
model_path
):
player
=
get_player
(
0.013
)
...
...
tensorpack/RL/atari.py
View file @
dee9d398
...
...
@@ -34,7 +34,7 @@ class AtariPlayer(RLEnvironment):
live_lost_as_eoe
=
True
):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param frame_skip: skip every k frames
and repeat the action
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable
...
...
@@ -57,10 +57,8 @@ class AtariPlayer(RLEnvironment):
self
.
ale
.
setBool
(
'color_averaging'
,
False
)
# manual.pdf suggests otherwise. may need to check
self
.
ale
.
setFloat
(
'repeat_action_probability'
,
0.0
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
...
...
@@ -77,9 +75,9 @@ class AtariPlayer(RLEnvironment):
self
.
nullop_start
=
nullop_start
self
.
height_range
=
height_range
self
.
image_shape
=
image_shape
self
.
current_episode_score
=
StatCounter
()
self
.
_reset
()
self
.
current_episode_score
=
StatCounter
()
self
.
restart_episode
()
def
_grab_raw_image
(
self
):
"""
...
...
@@ -112,7 +110,9 @@ class AtariPlayer(RLEnvironment):
"""
return
len
(
self
.
actions
)
def
_reset
(
self
):
def
restart_episode
(
self
):
if
self
.
current_episode_score
.
count
>
0
:
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
current_episode_score
.
reset
()
self
.
ale
.
reset_game
()
...
...
@@ -143,8 +143,7 @@ class AtariPlayer(RLEnvironment):
self
.
current_episode_score
.
feed
(
r
)
isOver
=
self
.
ale
.
game_over
()
if
isOver
:
self
.
stats
[
'score'
]
.
append
(
self
.
current_episode_score
.
sum
)
self
.
_reset
()
self
.
restart_episode
()
if
self
.
live_lost_as_eoe
:
isOver
=
isOver
or
newlives
<
oldlives
return
(
r
,
isOver
)
...
...
tensorpack/RL/common.py
View file @
dee9d398
...
...
@@ -39,6 +39,11 @@ class HistoryFramePlayer(ProxyPlayer):
self
.
history
.
append
(
s
)
return
(
r
,
isOver
)
def
restart_episode
(
self
):
super
(
HistoryFramePlayer
,
self
)
.
restart_episode
()
self
.
history
.
clear
()
self
.
history
.
append
(
self
.
player
.
current_state
())
class
PreventStuckPlayer
(
ProxyPlayer
):
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
...
...
@@ -63,6 +68,10 @@ class PreventStuckPlayer(ProxyPlayer):
self
.
act_que
.
clear
()
return
(
r
,
isOver
)
def
restart_episode
(
self
):
super
(
PreventStuckPlayer
,
self
)
.
restart_episode
()
self
.
act_que
.
clear
()
class
LimitLengthPlayer
(
ProxyPlayer
):
""" Limit the total number of actions in an episode"""
def
__init__
(
self
,
player
,
limit
):
...
...
@@ -73,8 +82,14 @@ class LimitLengthPlayer(ProxyPlayer):
def
action
(
self
,
act
):
r
,
isOver
=
self
.
player
.
action
(
act
)
self
.
cnt
+=
1
if
self
.
cnt
=
=
self
.
limit
:
if
self
.
cnt
>
=
self
.
limit
:
isOver
=
True
self
.
player
.
restart_episode
()
if
isOver
:
self
.
cnt
==
0
print
self
.
cnt
self
.
cnt
=
0
return
(
r
,
isOver
)
def
restart_episode
(
self
):
super
(
LimitLengthPlayer
,
self
)
.
restart_episode
()
self
.
cnt
=
0
tensorpack/RL/envbase.py
View file @
dee9d398
...
...
@@ -24,16 +24,20 @@ class RLEnvironment(object):
@
abstractmethod
def
action
(
self
,
act
):
"""
Perform an action
Perform an action
. Will automatically start a new episode if isOver==True
:params act: the action
:returns: (reward, isOver)
"""
@
abstractmethod
def
restart_episode
(
self
):
""" Start a new episode, even if the current hasn't ended """
def
get_stat
(
self
):
"""
return a dict of statistics (e.g., score)
after running for a while
return a dict of statistics (e.g., score)
for all the episodes since last call to reset_stat
"""
return
{}
def
reset_stat
(
self
):
""" reset the statistics counter"""
...
...
@@ -63,6 +67,8 @@ class NaiveRLEnvironment(RLEnvironment):
def
action
(
self
,
act
):
self
.
k
=
act
return
(
self
.
k
,
self
.
k
>
10
)
def
restart_episode
(
self
):
pass
class
ProxyPlayer
(
RLEnvironment
):
""" Serve as a proxy another player """
...
...
@@ -85,6 +91,5 @@ class ProxyPlayer(RLEnvironment):
def
stats
(
self
):
return
self
.
player
.
stats
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
return
self
.
player
.
play_one_episode
(
self
,
func
,
stat
)
def
restart_episode
(
self
):
self
.
player
.
restart_episode
()
tensorpack/tfutils/summary.py
View file @
dee9d398
...
...
@@ -74,7 +74,7 @@ def add_param_summary(summary_lists):
name
=
p
.
name
for
rgx
,
actions
in
summary_lists
:
if
not
rgx
.
endswith
(
'$'
):
rgx
=
rgx
+
'$'
rgx
=
rgx
+
'
(:0)?
$'
if
re
.
match
(
rgx
,
name
):
for
act
in
actions
:
perform
(
p
,
act
)
...
...
tensorpack/utils/stat.py
View file @
dee9d398
...
...
@@ -21,14 +21,17 @@ class StatCounter(object):
@
property
def
average
(
self
):
assert
len
(
self
.
values
)
return
np
.
mean
(
self
.
values
)
@
property
def
sum
(
self
):
assert
len
(
self
.
values
)
return
np
.
sum
(
self
.
values
)
@
property
def
max
(
self
):
assert
len
(
self
.
values
)
return
max
(
self
.
values
)
class
Accuracy
(
object
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment