Shashank Suhas / seminar-breakout / Commits / e5a48033

Commit e5a48033, authored Jun 12, 2016 by Yuxin Wu
Parent: 364fe347

    speedup expreplay a bit

Showing 4 changed files with 27 additions and 18 deletions (+27 -18)
examples/Atari2600/DQN.py       +2   -1
tensorpack/RL/atari.py          +2   -2
tensorpack/RL/expreplay.py      +21  -12
tensorpack/train/trainer.py     +2   -3
examples/Atari2600/DQN.py

@@ -133,8 +133,8 @@ class Model(ModelDesc):
             SummaryGradient()]
 
     def predictor(self, state):
+        # TODO change to a multitower predictor for speedup
         return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
-        #return self.predict_value.eval(feed_dict={'input_deque:0': [state]})[0]
 
 def get_config():
     basename = os.path.basename(__file__)

@@ -206,4 +206,5 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     SimpleTrainer(config).train()
     #QueueInputTrainer(config).train()
+    # TODO test if QueueInput affects learning
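For context, the predictor above evaluates Q-values for one state by feeding the 'state:0' placeholder. A minimal sketch of how such a per-state predictor is typically plugged into epsilon-greedy action selection; the predictor below is a random stand-in, not the model's graph:

import numpy as np

rng = np.random.RandomState(0)
NUM_ACTIONS = 4

def predictor(state):
    # stand-in for self.predict_value.eval(feed_dict={'state:0': [state]})[0]
    return rng.rand(NUM_ACTIONS)

def epsilon_greedy(state, exploration):
    # with probability `exploration`, act randomly; otherwise act greedily on Q-values
    if rng.rand() <= exploration:
        return rng.choice(range(NUM_ACTIONS))
    return int(np.argmax(predictor(state)))

action = epsilon_greedy(np.zeros((84, 84, 4), dtype='float32'), exploration=0.1)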
tensorpack/RL/atari.py

@@ -107,7 +107,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) image
+        :returns: a gray-scale (h, w, 1) float32 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen

@@ -117,7 +117,7 @@ class AtariPlayer(RLEnvironment):
             #m = cv2.resize(ret, (1920,1200))
             cv2.imshow(self.windowname, ret)
             time.sleep(self.viz)
-        ret = ret[self.height_range[0]:self.height_range[1],:]
+        ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
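The change above makes current_state() return a float32 frame: the cropped screen is cast to float32 before the grayscale conversion and resize. A rough standalone sketch of that pipeline, assuming an RGB uint8 screen; the crop range and output shape here are illustrative, not the repo's actual values:

import numpy as np
import cv2

height_range = (36, 204)   # illustrative crop of the playing area
image_shape = (84, 84)     # illustrative output size

def preprocess(raw_rgb):
    # crop and cast to float32, as in the diff above
    ret = raw_rgb[height_range[0]:height_range[1], :].astype('float32')
    # 0.299, 0.587, 0.114 -- same weights as rgb2y in torch/image
    ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
    ret = cv2.resize(ret, image_shape)
    return ret[:, :, np.newaxis]   # (h, w, 1) float32

frame = preprocess(np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8))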
tensorpack/RL/expreplay.py

@@ -5,6 +5,7 @@
 import numpy as np
 from collections import deque, namedtuple
+import threading
 from tqdm import tqdm
 import six
@@ -48,6 +49,7 @@ class ExpReplay(DataFlow, Callback):
         if populate_size is not None:
             logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
             init_memory_size = populate_size
+        init_memory_size = int(init_memory_size)
 
         for k, v in locals().items():
             if k != 'self':
@@ -56,6 +58,7 @@ class ExpReplay(DataFlow, Callback):
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
+        self._init_memory_flag = threading.Event()
 
     def _init_memory(self):
         logger.info("Populating replay memory...")
@@ -69,13 +72,17 @@ class ExpReplay(DataFlow, Callback):
         with tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                self._populate_exp()
+                from copy import deepcopy
+                self.mem.append(deepcopy(self.mem[0]))
+                #self._populate_exp()
                 pbar.update()
+        self._init_memory_flag.set()
 
     def reset_state(self):
         raise RuntimeError("Don't run me in multiple processes")
 
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()
         if self.rng.rand() <= self.exploration:
             act = self.rng.choice(range(self.num_actions))
@@ -101,6 +108,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem.append(Experience(old_s, act, reward, isOver))
 
     def get_data(self):
+        self._init_memory_flag.wait()
         # new s is considered useless if isOver==True
         while True:
             batch_exp = [self._sample_one() for _ in range(self.batch_size)]
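The _init_memory_flag added above is a threading.Event: get_data() blocks on wait() until the code that populates the replay memory calls set(). A minimal sketch of that hand-off pattern, with simplified stand-ins rather than the library's classes:

import threading
import time

init_flag = threading.Event()
mem = []

def init_memory():
    # populate the replay memory, then release any consumer blocked in get_data()
    for i in range(5):
        mem.append(i)
        time.sleep(0.01)
    init_flag.set()

def get_data():
    init_flag.wait()          # block until the memory has been initialized
    return list(mem)

threading.Thread(target=init_memory).start()
print(get_data())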
@@ -125,27 +133,28 @@ class ExpReplay(DataFlow, Callback):
     def _sample_one(self):
         """ return the transition tuple for
-            [idx, idx+history_len] -> [idx+1, idx+1+history_len]
+            [idx, idx+history_len) -> [idx+1, idx+1+history_len)
             it's the transition from state idx+history_len-1 to state idx+history_len
         """
         # look for a state to start with
         # when x.isOver==True, (x+1).state is of a different episode
         idx = self.rng.randint(len(self.mem) - self.history_len - 1)
-        start_idx = idx + self.history_len - 1
+        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]
 
         def concat(idx):
-            v = [self.mem[x].state for x in range(idx, idx + self.history_len)]
+            v = [x.state for x in samples[idx:idx + self.history_len]]
             return np.concatenate(v, axis=2)
-        state = concat(idx)
-        next_state = concat(idx + 1)
-        reward = self.mem[start_idx].reward
-        action = self.mem[start_idx].action
-        isOver = self.mem[start_idx].isOver
+        state = concat(0)
+        next_state = concat(1)
+        start_mem = samples[-2]
+        reward, action, isOver = start_mem.reward, start_mem.action, start_mem.isOver
 
+        start_idx = self.history_len - 1
         # zero-fill state before starting
         zero_fill = False
         for k in range(1, self.history_len):
-            if self.mem[start_idx - k].isOver:
+            if samples[start_idx - k].isOver:
                 zero_fill = True
             if zero_fill:
                 state[:,:, -k - 1] = 0
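The rewrite of _sample_one above is the main speedup: random access into a collections.deque is linear-time, and the old code indexed self.mem roughly 2*history_len + 3 times per sample, while the new code indexes it history_len + 1 times to build `samples` and then works on that plain list. A rough timing illustration of the difference, using synthetic data rather than the library's code:

import time
from collections import deque

history_len = 4
mem = deque(range(1_000_000), maxlen=1_000_000)
idx = 500_000                          # an index deep inside the deque

def sample_old():
    # old style: every lookup walks the deque toward the requested index
    state = [mem[x] for x in range(idx, idx + history_len)]
    next_state = [mem[x] for x in range(idx + 1, idx + 1 + history_len)]
    return state, next_state, mem[idx + history_len - 1]

def sample_new():
    # new style: touch the deque once per element of the window, then use the list
    samples = [mem[k] for k in range(idx, idx + history_len + 1)]
    return samples[:history_len], samples[1:], samples[-2]

for fn in (sample_old, sample_new):
    start = time.time()
    for _ in range(1000):
        fn()
    print(fn.__name__, time.time() - start)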
@@ -157,8 +166,8 @@ class ExpReplay(DataFlow, Callback):
         state = np.array([e[0] for e in batch_exp])
         next_state = np.array([e[1] for e in batch_exp])
         reward = np.array([e[2] for e in batch_exp])
-        action = np.array([e[3] for e in batch_exp])
-        isOver = np.array([e[4] for e in batch_exp])
+        action = np.array([e[3] for e in batch_exp], dtype='int8')
+        isOver = np.array([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]
 
     # Callback-related:
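The last hunk pins the dtypes of the batched arrays: with the default dtype, a batch of Python int actions becomes int64 on most platforms, so dtype='int8' cuts the action array by 8x, and dtype='bool' makes the isOver flags explicit. A small illustration with a hypothetical batch:

import numpy as np

batch_exp = [(None, None, 0.0, 3, False) for _ in range(64)]   # hypothetical 64-sample batch

action_default = np.array([e[3] for e in batch_exp])                 # int64 on most platforms
action_int8 = np.array([e[3] for e in batch_exp], dtype='int8')
isOver = np.array([e[4] for e in batch_exp], dtype='bool')

print(action_default.dtype, action_default.nbytes)   # e.g. int64, 512 bytes
print(action_int8.dtype, action_int8.nbytes)         # int8, 64 bytes
print(isOver.dtype, isOver.nbytes)                   # bool, 64 bytes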
tensorpack/train/trainer.py
@@ -84,8 +84,7 @@ class EnqueueThread(threading.Thread):
                 if self.coord.should_stop():
                     return
                 feed = dict(zip(self.input_vars, dp))
-                #_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
-                #print size
+                #print self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                 self.op.run(feed_dict=feed)
         except tf.errors.CancelledError as e:
             pass
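EnqueueThread feeds datapoints into the training queue on a background thread and stops when the coordinator asks it to. A minimal stand-in for that loop using Python's queue.Queue in place of the TensorFlow queue; the names, queue size, and dataflow here are illustrative, not the trainer's API:

import queue
import threading

data_queue = queue.Queue(maxsize=50)
stop_event = threading.Event()        # stand-in for self.coord.should_stop()

def dataflow():
    i = 0
    while True:
        yield i
        i += 1

def enqueue_loop():
    for dp in dataflow():
        if stop_event.is_set():
            return
        data_queue.put(dp)            # stand-in for self.op.run(feed_dict=feed)

threading.Thread(target=enqueue_loop, daemon=True).start()
print([data_queue.get() for _ in range(5)])
stop_event.set()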
@@ -144,7 +143,7 @@ class QueueInputTrainer(Trainer):
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_model_inputs()
-        self.model.build_graph(model_inputs, True)
+        self.model.build_graph(self.dequed_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)