Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0c5e39eb
Commit
0c5e39eb
authored
May 24, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ataridriver as an rlenv
parent
d5d7270a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
30 additions
and
24 deletions
+30
-24
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-0
opt-requirements.txt
opt-requirements.txt
+3
-0
requirements.txt
requirements.txt
+0
-2
tensorpack/dataflow/RL.py
tensorpack/dataflow/RL.py
+1
-0
tensorpack/dataflow/dataset/atari.py
tensorpack/dataflow/dataset/atari.py
+21
-21
tensorpack/predict.py
tensorpack/predict.py
+1
-1
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+3
-0
No files found.
examples/Atari2600/DQN.py
View file @
0c5e39eb
...
...
@@ -37,6 +37,7 @@ NUM_ACTIONS = None
FRAME_HISTORY
=
4
ACTION_REPEAT
=
3
HEIGHT_RANGE
=
(
36
,
204
)
# for breakout
# HEIGHT_RANGE = (28, -8) # for pong
GAMMA
=
0.99
BATCH_SIZE
=
32
...
...
opt-requirements.txt
0 → 100644
View file @
0c5e39eb
nltk
h5py
pyzmq
requirements.txt
View file @
0c5e39eb
...
...
@@ -2,6 +2,4 @@ termcolor
pillow
scipy
tqdm
h5py
nltk
dill
tensorpack/dataflow/RL.py
View file @
0c5e39eb
...
...
@@ -126,6 +126,7 @@ class ExpReplay(DataFlow, Callback):
#print act, reward
#view_state(s)
# s is considered useless if isOver==True
self
.
mem
.
append
(
Experience
(
old_s
,
act
,
reward
,
s
,
isOver
))
def
get_data
(
self
):
...
...
tensorpack/dataflow/dataset/atari.py
View file @
0c5e39eb
...
...
@@ -18,12 +18,11 @@ except ImportError:
__all__
=
[
'AtariDriver'
,
'AtariPlayer'
]
class
AtariDriver
(
objec
t
):
class
AtariDriver
(
RLEnvironmen
t
):
"""
A wrapper for atari emulator.
"""
def
__init__
(
self
,
rom_file
,
frame_skip
=
1
,
viz
=
0
,
height_range
=
(
None
,
None
)):
def
__init__
(
self
,
rom_file
,
viz
=
0
,
height_range
=
(
None
,
None
)):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
...
...
@@ -33,7 +32,8 @@ class AtariDriver(object):
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
self
.
rng
.
randint
(
0
,
1000
)))
self
.
ale
.
setInt
(
"frame_skip"
,
frame_skip
)
self
.
ale
.
setInt
(
"frame_skip"
,
1
)
self
.
ale
.
setBool
(
'color_averaging'
,
True
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
...
...
@@ -46,10 +46,10 @@ class AtariDriver(object):
cv2
.
startWindowThread
()
cv2
.
namedWindow
(
self
.
romname
)
self
.
_reset
()
self
.
last_image
=
self
.
_grab_raw_image
()
self
.
framenum
=
0
self
.
height_range
=
height_range
self
.
framenum
=
0
self
.
_reset
()
def
_grab_raw_image
(
self
):
"""
...
...
@@ -59,14 +59,11 @@ class AtariDriver(object):
self
.
ale
.
getScreenRGB
(
m
)
return
m
.
reshape
((
self
.
height
,
self
.
width
,
3
))
def
grab_imag
e
(
self
):
def
current_stat
e
(
self
):
"""
:returns: a gray-scale image, max-pooled over the last frame.
"""
now
=
self
.
_grab_raw_image
()
ret
=
np
.
maximum
(
now
,
self
.
last_image
)
self
.
last_image
=
now
ret
=
ret
[
self
.
height_range
[
0
]:
self
.
height_range
[
1
],:]
# several online repos all use this
if
self
.
viz
:
if
isinstance
(
self
.
viz
,
float
):
cv2
.
imshow
(
self
.
romname
,
ret
)
...
...
@@ -74,6 +71,7 @@ class AtariDriver(object):
else
:
cv2
.
imwrite
(
"{}/{:06d}.jpg"
.
format
(
self
.
viz
,
self
.
framenum
),
ret
)
self
.
framenum
+=
1
ret
=
ret
[
self
.
height_range
[
0
]:
self
.
height_range
[
1
],:]
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_BGR2YUV
)[:,:,
0
]
return
ret
...
...
@@ -86,17 +84,16 @@ class AtariDriver(object):
def
_reset
(
self
):
self
.
ale
.
reset_game
()
def
next
(
self
,
act
):
def
action
(
self
,
act
):
"""
:param act: an index of the action
:returns: (
next_image,
reward, isOver)
:returns: (reward, isOver)
"""
r
=
self
.
ale
.
act
(
self
.
actions
[
act
])
s
=
self
.
grab_image
()
isOver
=
self
.
ale
.
game_over
()
if
isOver
:
self
.
_reset
()
return
(
s
,
r
,
isOver
)
return
(
r
,
isOver
)
class
AtariPlayer
(
RLEnvironment
):
""" An Atari game player with limited memory and FPS"""
...
...
@@ -122,7 +119,7 @@ class AtariPlayer(RLEnvironment):
"""
self
.
current_accum_score
=
0
self
.
frames
.
clear
()
s
=
self
.
driver
.
grab_imag
e
()
s
=
self
.
driver
.
current_stat
e
()
s
=
cv2
.
resize
(
s
,
self
.
image_shape
)
for
_
in
range
(
self
.
hist_len
):
...
...
@@ -138,7 +135,7 @@ class AtariPlayer(RLEnvironment):
"""
Perform an action
:param act: index of the action
:returns: (
new_frame,
reward, isOver)
:returns: (reward, isOver)
"""
self
.
last_act
=
act
return
self
.
_observe
()
...
...
@@ -154,7 +151,8 @@ class AtariPlayer(RLEnvironment):
"""
totr
=
0
for
k
in
range
(
self
.
action_repeat
):
s
,
r
,
isOver
=
self
.
driver
.
next
(
self
.
last_act
)
r
,
isOver
=
self
.
driver
.
action
(
self
.
last_act
)
s
=
self
.
driver
.
current_state
()
totr
+=
r
if
isOver
:
break
...
...
@@ -174,7 +172,8 @@ class AtariPlayer(RLEnvironment):
return
{}
if
__name__
==
'__main__'
:
a
=
AtariDriver
(
'breakout.bin'
,
viz
=
True
)
import
sys
a
=
AtariDriver
(
sys
.
argv
[
1
],
viz
=
0.01
,
height_range
=
(
28
,
-
8
))
num
=
a
.
get_num_actions
()
rng
=
get_rng
(
num
)
import
time
...
...
@@ -182,7 +181,8 @@ if __name__ == '__main__':
#im = a.grab_image()
#cv2.imshow(a.romname, im)
act
=
rng
.
choice
(
range
(
num
))
s
,
r
,
o
=
a
.
next
(
act
)
time
.
sleep
(
0.1
)
print
act
r
,
o
=
a
.
action
(
act
)
#time.sleep(0.1)
print
(
r
,
o
)
tensorpack/predict.py
View file @
0c5e39eb
...
...
@@ -118,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
self
.
gpuid
else
:
logger
.
info
(
"Worker {} uses CPU"
.
format
(
self
.
idx
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'
0
'
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
G
=
tf
.
Graph
()
# build a graph for each process, because they don't need to share anything
with
G
.
as_default
(),
tf
.
device
(
'/gpu:0'
if
self
.
gpuid
>=
0
else
'/cpu:0'
):
if
self
.
idx
!=
0
:
...
...
tensorpack/tfutils/summary.py
View file @
0c5e39eb
...
...
@@ -34,6 +34,9 @@ def add_activation_summary(x, name=None):
name
=
x
.
name
tf
.
histogram_summary
(
name
+
'/activation'
,
x
)
tf
.
scalar_summary
(
name
+
'/activation_sparsity'
,
tf
.
nn
.
zero_fraction
(
x
))
tf
.
scalar_summary
(
name
+
'/activation_rms'
,
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
))))
def
add_param_summary
(
summary_lists
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment