Shashank Suhas / seminar-breakout · Commits

Commit 90a74f02, authored May 23, 2016 by Yuxin Wu
parent 48115f68

stats for RLEnv

Showing 3 changed files, with 41 additions and 6 deletions:

    examples/Atari2600/DQN.py              +2   -2
    tensorpack/dataflow/RL.py              +25  -4
    tensorpack/dataflow/dataset/atari.py   +14  -0
examples/Atari2600/DQN.py (view file @ 90a74f02)

@@ -255,8 +255,8 @@ def get_config(romfile):
         output = output.strip()
         output = output[output.find(']')+1:]
         mean, maximum = re.findall('[0-9\.]+', output)
-        self.trainer.write_scalar_summary('mean_score', mean)
-        self.trainer.write_scalar_summary('max_score', maximum)
+        self.trainer.write_scalar_summary('eval_mean_score', mean)
+        self.trainer.write_scalar_summary('eval_max_score', maximum)

    return TrainConfig(
        dataset=dataset_train,
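The context lines above show how those two numbers are produced: the evaluation output is stripped of its bracketed log prefix, and the first two numbers after the prefix are pulled out with a regex. A minimal sketch of that parsing, using a hypothetical log line (the real format comes from the evaluation subprocess and is not shown in this diff):

    import re

    # Hypothetical eval output line; only its shape matters here.
    output = "[0523 10:00:00 @eval] mean score: 152.3, max score: 401.0"
    output = output.strip()
    output = output[output.find(']') + 1:]          # drop the "[...]" log prefix
    mean, maximum = re.findall('[0-9\.]+', output)  # expects exactly two numbers
    print(mean)      # '152.3'
    print(maximum)   # '401.0'

Note that both values come back as strings, which is what the diff passes on to write_scalar_summary.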
tensorpack/dataflow/RL.py (view file @ 90a74f02)

@@ -6,9 +6,10 @@
 from abc import abstractmethod, ABCMeta
 import random
 import numpy as np
-from collections import deque, namedtuple
+from collections import deque, namedtuple, defaultdict
 from tqdm import tqdm
 import cv2
+import six

 from .base import DataFlow
 from tensorpack.utils import *
@@ -26,6 +27,9 @@ Experience = namedtuple('Experience',
 class RLEnvironment(object):
     __meta__ = ABCMeta

+    def __init__(self):
+        self.reset_stat()
+
     @abstractmethod
     def current_state(self):
         """
@@ -40,6 +44,16 @@ class RLEnvironment(object):
         :returns: (reward, isOver)
         """

+    @abstractmethod
+    def get_stat(self):
+        """
+        return a dict of statistics (e.g., score) after running for a while
+        """
+
+    def reset_stat(self):
+        """ reset the statistics counter"""
+        self.stats = defaultdict(list)
+
 class NaiveRLEnvironment(RLEnvironment):
     """ for testing only"""
     def __init__(self):
@@ -67,7 +81,9 @@ class ExpReplay(DataFlow, Callback):
                  exploration=1,
                  end_exploration=0.1,
                  exploration_epoch_anneal=0.002,
-                 reward_clip=None):
+                 reward_clip=None,
+                 new_experience_per_step=1):
         """
         :param predictor: callable. called with a state, return a distribution
         :param player: a `RLEnvironment`
@@ -117,6 +133,7 @@ class ExpReplay(DataFlow, Callback):
             idxs = self.rng.randint(len(self.mem), size=self.batch_size)
             batch_exp = [self.mem[k] for k in idxs]
             yield self._process_batch(batch_exp)
-            self._populate_exp()
+            for _ in range(self.new_experience_per_step):
+                self._populate_exp()

     def _process_batch(self, batch_exp):
@@ -144,7 +161,11 @@ class ExpReplay(DataFlow, Callback):
         if self.exploration > self.end_exploration:
             self.exploration -= self.exploration_epoch_anneal
             logger.info("Exploration changed to {}".format(self.exploration))
+        stats = self.player.get_stat()
+        for k, v in six.iteritems(stats):
+            if isinstance(v, float):
+                self.trainer.write_scalar_summary('expreplay/' + k, v)
+        self.player.reset_stat()

 if __name__ == '__main__':
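Read together, the RL.py changes define a small statistics contract: RLEnvironment.__init__ seeds self.stats as a defaultdict(list) via reset_stat(), subclasses append to it while playing, get_stat() summarizes it, and at the end of each epoch ExpReplay writes every float in that summary as an 'expreplay/'-prefixed scalar before resetting the counters. Below is a minimal sketch of a conforming environment; the class, its ten-step episodes, and its scoring rule are hypothetical, it assumes RLEnvironment is importable from tensorpack.dataflow.RL as in this commit, and it assumes the abstract step method is named action (consistent with the ':returns: (reward, isOver)' context above, though the diff does not show the name):

    import numpy as np
    from tensorpack.dataflow.RL import RLEnvironment

    class ToyEnv(RLEnvironment):
        """Hypothetical environment, only to illustrate the stats contract."""
        def __init__(self):
            super(ToyEnv, self).__init__()   # base __init__ calls reset_stat()
            self._cnt = 0

        def current_state(self):
            return self._cnt

        def action(self, act):
            self._cnt += 1
            isOver = (self._cnt % 10 == 0)   # pretend an episode lasts 10 steps
            if isOver:
                self.stats['score'].append(float(self._cnt))
            return (1.0, isOver)

        def get_stat(self):
            scores = self.stats['score']
            return {'avg_score': float(np.mean(scores))} if scores else {}

Because get_stat() returns plain floats keyed by name, the logging loop in ExpReplay needs no knowledge of the particular environment.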
tensorpack/dataflow/dataset/atari.py (view file @ 90a74f02)

@@ -107,17 +107,20 @@ class AtariPlayer(RLEnvironment):
         :param action_repeat: repeat each action `action_repeat` times and skip those frames
         :param image_shape: the shape of the observed image
         """
+        super(AtariPlayer, self).__init__()
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
         self.last_act = 0
         self.frames = deque(maxlen=hist_len)
+        self.current_accum_score = 0
         self.restart()

     def restart(self):
         """
         Restart the game and populate frames with the beginning frame
         """
+        self.current_accum_score = 0
         self.frames.clear()
         s = self.driver.grab_image()
@@ -156,11 +159,22 @@ class AtariPlayer(RLEnvironment):
             if isOver:
                 break
         s = cv2.resize(s, self.image_shape)
+        self.current_accum_score += totr
         self.frames.append(s)
+        if isOver:
+            self.stats['score'].append(self.current_accum_score)
+            self.restart()
         return (totr, isOver)

+    def get_stat(self):
+        try:
+            print self.stats
+            return {'avg_score': np.mean(self.stats['score']),
+                    'max_score': float(np.max(self.stats['score']))}
+        except ValueError:
+            return {}
+
 if __name__ == '__main__':
     a = AtariDriver('breakout.bin', viz=True)
     num = a.get_num_actions()
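One detail worth noting in get_stat(): np.max raises ValueError on an empty sequence, so until the first episode finishes the method falls through to return {}, and ExpReplay simply has nothing to log. A small self-contained sketch of the same bookkeeping, with hypothetical episode totals:

    from collections import defaultdict
    import numpy as np

    stats = defaultdict(list)

    # Hypothetical per-episode totals, appended the way AtariPlayer
    # does when isOver becomes True.
    for episode_score in (120.0, 80.0, 210.0):
        stats['score'].append(episode_score)

    try:
        summary = {'avg_score': np.mean(stats['score']),
                   'max_score': float(np.max(stats['score']))}
    except ValueError:       # np.max([]) raises; np.mean([]) only warns
        summary = {}

    print(summary)   # {'avg_score': 136.666..., 'max_score': 210.0}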