Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
704bee73
Commit
704bee73
authored
May 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move atari driver
parent
17687d5c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
8 deletions
+75
-8
tensorpack/dataflow/dataset/.gitignore
tensorpack/dataflow/dataset/.gitignore
+1
-0
tensorpack/dataflow/dataset/atari.py
tensorpack/dataflow/dataset/atari.py
+69
-7
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+5
-1
No files found.
tensorpack/dataflow/dataset/.gitignore
View file @
704bee73
...
@@ -2,3 +2,4 @@ mnist_data
...
@@ -2,3 +2,4 @@ mnist_data
cifar10_data
cifar10_data
svhn_data
svhn_data
ilsvrc_metadata
ilsvrc_metadata
bsds500_data
tensorpack/
utils
/atari.py
→
tensorpack/
dataflow/dataset
/atari.py
View file @
704bee73
...
@@ -2,19 +2,19 @@
...
@@ -2,19 +2,19 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: atari.py
# File: atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
ale_python_interface
import
ALEInterface
from
ale_python_interface
import
ALEInterface
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
os
import
os
import
cv2
import
cv2
from
.utils
import
get_rng
from
collections
import
deque
from
...utils
import
get_rng
__all__
=
[
'AtariDriver'
]
__all__
=
[
'AtariDriver'
,
'AtariPlayer'
]
class
AtariDriver
(
object
):
class
AtariDriver
(
object
):
"""
"""
A
driver for atari games
.
A
wrapper for atari emulator
.
"""
"""
def
__init__
(
self
,
rom_file
,
frame_skip
=
1
,
viz
=
0
):
def
__init__
(
self
,
rom_file
,
frame_skip
=
1
,
viz
=
0
):
"""
"""
...
@@ -25,7 +25,7 @@ class AtariDriver(object):
...
@@ -25,7 +25,7 @@ class AtariDriver(object):
self
.
ale
=
ALEInterface
()
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
999
))
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
self
.
rng
.
randint
(
0
,
1000
)
))
self
.
ale
.
setInt
(
"frame_skip"
,
frame_skip
)
self
.
ale
.
setInt
(
"frame_skip"
,
frame_skip
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
...
@@ -42,7 +42,7 @@ class AtariDriver(object):
...
@@ -42,7 +42,7 @@ class AtariDriver(object):
def
_grab_raw_image
(
self
):
def
_grab_raw_image
(
self
):
"""
"""
:returns:
a
3-channel image
:returns:
the current
3-channel image
"""
"""
m
=
np
.
zeros
(
self
.
height
*
self
.
width
*
3
,
dtype
=
np
.
uint8
)
m
=
np
.
zeros
(
self
.
height
*
self
.
width
*
3
,
dtype
=
np
.
uint8
)
self
.
ale
.
getScreenRGB
(
m
)
self
.
ale
.
getScreenRGB
(
m
)
...
@@ -50,7 +50,7 @@ class AtariDriver(object):
...
@@ -50,7 +50,7 @@ class AtariDriver(object):
def
grab_image
(
self
):
def
grab_image
(
self
):
"""
"""
:returns: a gray-scale image, max
imum over the last
:returns: a gray-scale image, max
-pooled over the last frame.
"""
"""
now
=
self
.
_grab_raw_image
()
now
=
self
.
_grab_raw_image
()
ret
=
np
.
maximum
(
now
,
self
.
last_image
)
ret
=
np
.
maximum
(
now
,
self
.
last_image
)
...
@@ -82,6 +82,68 @@ class AtariDriver(object):
...
@@ -82,6 +82,68 @@ class AtariDriver(object):
self
.
_reset
()
self
.
_reset
()
return
(
s
,
r
,
isOver
)
return
(
s
,
r
,
isOver
)
class AtariPlayer(object):
    """An Atari game player with limited memory and FPS.

    Wraps a driver object, repeats every chosen action `action_repeat` times
    and keeps the `hist_len` most recent resized frames in a bounded deque.
    """

    def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84, 84)):
        """
        :param driver: an `AtariDriver` instance.
        :param hist_len: history(memory) length
        :param action_repeat: repeat each action `action_repeat` times and skip those frames
        :param image_shape: the shape of the observed image
        :raises ValueError: if `action_repeat` is smaller than 1
        """
        # With zero repeats `_grab` would never call the driver and would have
        # no frame to return, so reject that configuration up front.
        if action_repeat < 1:
            raise ValueError(
                "action_repeat must be >= 1, got {}".format(action_repeat))
        self.driver = driver
        self.hist_len = hist_len
        self.action_repeat = action_repeat
        self.image_shape = image_shape
        self.last_act = 0
        # Bounded deque: appending a new frame drops the oldest one.
        self.frames = deque(maxlen=hist_len)
        self.restart()

    def restart(self):
        """
        Restart the game and populate frames with the beginning frame
        """
        self.frames.clear()
        first = self._resize(self.driver.grab_image())
        # Fill the whole history with the same frame so that a full state
        # can be built immediately.
        self.frames.extend([first] * self.hist_len)

    def current_state(self):
        """
        Return a current state of shape `image_shape + (hist_len,)`
        """
        return self._build_state()

    def action(self, act):
        """
        Perform an action
        :param act: index of the action
        :returns: (new_frame, reward, isOver)
        """
        self.last_act = act
        return self._grab()

    def _resize(self, img):
        """Resize `img` so that the resulting array has shape `image_shape`."""
        # cv2.resize expects dsize as (width, height) while `image_shape` is
        # documented as the shape of the observed image, hence the reversal.
        return cv2.resize(img, tuple(self.image_shape[::-1]))

    def _build_state(self):
        """Stack the stored frames along a new last axis."""
        assert len(self.frames) == self.hist_len
        stacked = np.array(self.frames)
        # Move the history axis from the front to the back.
        return stacked.transpose([1, 2, 0])

    def _grab(self):
        """ if isOver==True, current_state will return the new episode
        """
        total_reward = 0
        for _ in range(self.action_repeat):
            frame, reward, is_over = self.driver.next(self.last_act)
            total_reward += reward
            if is_over:
                break
        frame = self._resize(frame)
        self.frames.append(frame)
        if is_over:
            # Refill the history from whatever the driver shows now.
            self.restart()
        return (frame, total_reward, is_over)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
a
=
AtariDriver
(
'breakout.bin'
,
viz
=
True
)
a
=
AtariDriver
(
'breakout.bin'
,
viz
=
True
)
num
=
a
.
get_num_actions
()
num
=
a
.
get_num_actions
()
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
704bee73
...
@@ -55,7 +55,6 @@ def logSoftmax(x):
...
@@ -55,7 +55,6 @@ def logSoftmax(x):
logprob
=
z
-
tf
.
log
(
tf
.
reduce_sum
(
tf
.
exp
(
z
),
1
,
keep_dims
=
True
))
logprob
=
z
-
tf
.
log
(
tf
.
reduce_sum
(
tf
.
exp
(
z
),
1
,
keep_dims
=
True
))
return
logprob
return
logprob
def
class_balanced_binary_class_cross_entropy
(
pred
,
label
,
name
=
'cross_entropy_loss'
):
def
class_balanced_binary_class_cross_entropy
(
pred
,
label
,
name
=
'cross_entropy_loss'
):
"""
"""
The class-balanced cross entropy loss for binary classification,
The class-balanced cross entropy loss for binary classification,
...
@@ -80,3 +79,8 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
...
@@ -80,3 +79,8 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
cost
=
tf
.
reduce_mean
(
cost
,
name
=
name
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
name
)
return
cost
return
cost
def print_stat(x):
    """Attach a simple print op to `x`.

    The op is built with `tf.Print`, which is given the mean of `x` and `x`
    itself as the data to print, with `summarize=20`.

    Use it like: x = print_stat(x)
    """
    printed = [tf.reduce_mean(x), x]
    return tf.Print(x, printed, summarize=20)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment