Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
961b0ee4
Commit
961b0ee4
authored
May 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move exp_replay
parent
fc0b965a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
13 deletions
+26
-13
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+4
-2
tensorpack/__init__.py
tensorpack/__init__.py
+3
-2
tensorpack/dataflow/RL.py
tensorpack/dataflow/RL.py
+19
-9
No files found.
examples/Atari2600/DQN.py
View file @
961b0ee4
...
@@ -245,15 +245,17 @@ def get_config(romfile):
...
@@ -245,15 +245,17 @@ def get_config(romfile):
global
NUM_ACTIONS
global
NUM_ACTIONS
NUM_ACTIONS
=
driver
.
get_num_actions
()
NUM_ACTIONS
=
driver
.
get_num_actions
()
dataset_train
=
Atari
ExpReplay
(
dataset_train
=
ExpReplay
(
predictor
=
current_predictor
,
predictor
=
current_predictor
,
player
=
AtariPlayer
(
player
=
AtariPlayer
(
driver
,
hist_len
=
FRAME_HISTORY
,
driver
,
hist_len
=
FRAME_HISTORY
,
action_repeat
=
ACTION_REPEAT
),
action_repeat
=
ACTION_REPEAT
),
num_actions
=
NUM_ACTIONS
,
memory_size
=
MEMORY_SIZE
,
memory_size
=
MEMORY_SIZE
,
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
populate_size
=
INIT_MEMORY_SIZE
,
populate_size
=
INIT_MEMORY_SIZE
,
exploration
=
INIT_EXPLORATION
)
exploration
=
INIT_EXPLORATION
,
reward_clip
=
(
-
1
,
2
))
lr
=
tf
.
Variable
(
0.0025
,
trainable
=
False
,
name
=
'learning_rate'
)
lr
=
tf
.
Variable
(
0.0025
,
trainable
=
False
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
...
...
tensorpack/__init__.py
View file @
961b0ee4
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
import
numpy
# avoid https://github.com/tensorflow/tensorflow/issues/2034
import
cv2
# fix https://github.com/tensorflow/tensorflow/issues/1924
import
cv2
# avoid https://github.com/tensorflow/tensorflow/issues/1924
from
.
import
models
from
.
import
models
from
.
import
train
from
.
import
train
from
.
import
utils
from
.
import
utils
...
...
examples/Atari2600/exp_replay
.py
→
tensorpack/dataflow/RL
.py
View file @
961b0ee4
#!/usr/bin/env python2
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File:
exp_replay
.py
# File:
RL
.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
tensorpack.dataflow
import
*
from
.base
import
DataFlow
from
tensorpack.dataflow.dataset
import
AtariDriver
,
AtariPlayer
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
random
import
random
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
from
collections
import
deque
,
namedtuple
from
collections
import
deque
,
namedtuple
"""
Implement RL-related data preprocessing
"""
__all__
=
[
'ExpReplay'
]
Experience
=
namedtuple
(
'Experience'
,
Experience
=
namedtuple
(
'Experience'
,
[
'state'
,
'action'
,
'reward'
,
'next'
,
'isOver'
])
[
'state'
,
'action'
,
'reward'
,
'next'
,
'isOver'
])
def
view_state
(
state
):
def
view_state
(
state
):
# for debug
r
=
np
.
concatenate
([
state
[:,:,
k
]
for
k
in
range
(
state
.
shape
[
2
])],
axis
=
1
)
r
=
np
.
concatenate
([
state
[:,:,
k
]
for
k
in
range
(
state
.
shape
[
2
])],
axis
=
1
)
print
r
.
shape
print
r
.
shape
cv2
.
imshow
(
"state"
,
r
)
cv2
.
imshow
(
"state"
,
r
)
cv2
.
waitKey
()
cv2
.
waitKey
()
class
Atari
ExpReplay
(
DataFlow
):
class
ExpReplay
(
DataFlow
):
"""
"""
Implement experience replay
Implement experience replay
.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
predictor
,
predictor
,
player
,
player
,
num_actions
,
memory_size
=
1e6
,
memory_size
=
1e6
,
batch_size
=
32
,
batch_size
=
32
,
populate_size
=
50000
,
populate_size
=
50000
,
exploration
=
1
):
exploration
=
1
,
reward_clip
=
None
):
"""
"""
:param predictor: callable. Called with a state, returns a distribution.
:param predictor: callable. Called with a state, returns a distribution.
:param player: a `RLEnvironment`
"""
"""
for
k
,
v
in
locals
()
.
items
():
for
k
,
v
in
locals
()
.
items
():
if
k
!=
'self'
:
if
k
!=
'self'
:
setattr
(
self
,
k
,
v
)
setattr
(
self
,
k
,
v
)
self
.
num_actions
=
self
.
player
.
driver
.
get_num_actions
()
logger
.
info
(
"Number of Legal actions: {}"
.
format
(
self
.
num_actions
))
logger
.
info
(
"Number of Legal actions: {}"
.
format
(
self
.
num_actions
))
self
.
mem
=
deque
(
maxlen
=
memory_size
)
self
.
mem
=
deque
(
maxlen
=
memory_size
)
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
...
@@ -62,7 +70,8 @@ class AtariExpReplay(DataFlow):
...
@@ -62,7 +70,8 @@ class AtariExpReplay(DataFlow):
else
:
else
:
act
=
np
.
argmax
(
self
.
predictor
(
old_s
))
# TODO race condition in session?
act
=
np
.
argmax
(
self
.
predictor
(
old_s
))
# TODO race condition in session?
reward
,
isOver
=
self
.
player
.
action
(
act
)
reward
,
isOver
=
self
.
player
.
action
(
act
)
reward
=
np
.
clip
(
reward
,
-
1
,
2
)
if
self
.
reward_clip
:
reward
=
np
.
clip
(
reward
,
self
.
reward_clip
[
0
],
self
.
reward_clip
[
1
])
s
=
self
.
player
.
current_state
()
s
=
self
.
player
.
current_state
()
#print act, reward
#print act, reward
...
@@ -94,6 +103,7 @@ class AtariExpReplay(DataFlow):
...
@@ -94,6 +103,7 @@ class AtariExpReplay(DataFlow):
return
[
state
,
action
,
reward
,
next_state
,
isOver
]
return
[
state
,
action
,
reward
,
next_state
,
isOver
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
tensorpack.dataflow.dataset
import
AtariDriver
,
AtariPlayer
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
predictor
.
initialized
=
False
predictor
.
initialized
=
False
E
=
AtariExpReplay
(
predictor
,
predictor
,
E
=
AtariExpReplay
(
predictor
,
predictor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment