Commit de6d5502 authored May 18, 2016 by Yuxin Wu
rl environment base
parent 97dd6c5c

Showing 4 changed files with 33 additions and 7 deletions (+33 -7)
examples/Atari2600/DQN.py               +2  -2
examples/Atari2600/exp_replay.py        +1  -1
tensorpack/dataflow/dataset/atari.py    +5  -4
tensorpack/dataflow/dataset/rlenv.py    +25 -0
examples/Atari2600/DQN.py

@@ -168,7 +168,7 @@ def play_model(model_path, romfile):
             act = 1
         que.append(act)
         print(act)
-        _, reward, isOver = player.action(act)
+        reward, isOver = player.action(act)
         tot_reward += reward
         if isOver:
             print("Total:", tot_reward)
@@ -210,7 +210,7 @@ def eval_model_multiprocess(model_path, romfile):
                 act = 1
             que.append(act)
             #print(act)
-            _, reward, isOver = player.action(act)
+            reward, isOver = player.action(act)
             tot_reward += reward
             if isOver:
                 self.outq.put(tot_reward)
examples/Atari2600/exp_replay.py

@@ -61,7 +61,7 @@ class AtariExpReplay(DataFlow):
             act = self.rng.choice(range(self.num_actions))
         else:
             act = np.argmax(self.predictor(old_s))  # TODO race condition in session?
-        _, reward, isOver = self.player.action(act)
+        reward, isOver = self.player.action(act)
         reward = np.clip(reward, -1, 2)
         s = self.player.current_state()
tensorpack/dataflow/dataset/atari.py

@@ -9,6 +9,7 @@ import os
 import cv2
 from collections import deque
 from ...utils import get_rng
+from . import RLEnvironment

 __all__ = ['AtariDriver', 'AtariPlayer']
@@ -86,7 +87,7 @@ class AtariDriver(object):
             self._reset()
         return (s, r, isOver)

-class AtariPlayer(object):
+class AtariPlayer(RLEnvironment):
     """ An Atari game player with limited memory and FPS"""
     def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
         """
@@ -125,7 +126,7 @@ class AtariPlayer(object):
         :returns: (new_frame, reward, isOver)
         """
         self.last_act = act
-        return self._grab()
+        return self._observe()

     def _build_state(self):
         assert len(self.frames) == self.hist_len
@@ -133,7 +134,7 @@ class AtariPlayer(object):
         m = m.transpose([1, 2, 0])
         return m

-    def _grab(self):
+    def _observe(self):
         """ if isOver==True, current_state will return the new episode
         """
         totr = 0
@@ -146,7 +147,7 @@ class AtariPlayer(object):
             self.frames.append(s)
         if isOver:
             self.restart()
-        return (s, totr, isOver)
+        return (totr, isOver)

 if __name__ == '__main__':
     a = AtariDriver('breakout.bin', viz=True)
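
After this commit, AtariPlayer.action(act) returns only (reward, isOver); the observation is read separately through current_state(), which is what the updated call sites in DQN.py and exp_replay.py above rely on. Below is a minimal interaction-loop sketch against that API, not code from the commit: the import path is inferred from the file locations in this repository, the ROM path 'breakout.bin' is taken from the __main__ block above, and the fixed action 1 is only a placeholder policy.

# Sketch only: assumes the import path below resolves and an ALE ROM
# named 'breakout.bin' is available in the working directory.
from tensorpack.dataflow.dataset.atari import AtariDriver, AtariPlayer

driver = AtariDriver('breakout.bin', viz=True)
player = AtariPlayer(driver, hist_len=4, action_repeat=4, image_shape=(84, 84))

tot_reward = 0
while True:
    s = player.current_state()           # stacked frame history, fetched separately
    act = 1                              # placeholder policy (same fallback action as DQN.py)
    reward, isOver = player.action(act)  # new signature: no frame in the return value
    tot_reward += reward
    if isOver:
        print("Total:", tot_reward)
        break

Dropping the frame from action()'s return value keeps the environment contract minimal: callers always go through current_state() for observations instead of mixing a raw frame with the stacked history.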
tensorpack/dataflow/dataset/rlenv.py  (new file, 0 → 100644)
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: rlenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from abc import abstractmethod, ABCMeta

__all__ = ['RLEnvironment']


class RLEnvironment(object):
    __meta__ = ABCMeta

    @abstractmethod
    def current_state(self):
        """
        Observe, return a state representation
        """

    @abstractmethod
    def action(self, act):
        """
        Perform an action
        :params act: the action
        :returns: (reward, isOver)
        """