Shashank Suhas / seminar-breakout · Commits
Commit 785e01e2
authored Jun 13, 2016 by Yuxin Wu
speedup expreplay by 1.3x
parent f1fc7337
Showing 1 changed file with 17 additions and 4 deletions

tensorpack/RL/expreplay.py  (+17 -4)
@@ -8,9 +8,11 @@ from collections import deque, namedtuple
 import threading
 from tqdm import tqdm
 import six
+from six.moves import queue

 from ..dataflow import DataFlow
 from ..utils import *
+from ..utils.concurrency import LoopThread
 from ..callbacks.base import Callback

 __all__ = ['ExpReplay']
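The two added imports are what the rest of the commit builds on: six.moves.queue provides a Python 2/3-compatible, thread-safe queue module, and LoopThread (a tensorpack concurrency helper, used further down as LoopThread(populate_job_func, False)) runs a function repeatedly on a background thread. A tiny standalone check of the compat import, not taken from the repository:

# Illustration only: `six.moves.queue` resolves to the stdlib `Queue` module
# on Python 2 and to `queue` on Python 3, so one import works on both.
from six.moves import queue

jobs = queue.Queue(maxsize=1)   # bounded, thread-safe FIFO
jobs.put(1)
assert jobs.get() == 1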
@@ -58,7 +60,7 @@ class ExpReplay(DataFlow, Callback):
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
-        self._init_memory_flag = threading.Event()
+        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

     def _init_memory(self):
         logger.info("Populating replay memory...")
@@ -72,6 +74,8 @@ class ExpReplay(DataFlow, Callback):
         with tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
+                #from copy import deepcopy  # for debug
+                #self.mem.append(deepcopy(self.mem[0]))
                 self._populate_exp()
                 pbar.update()
         self._init_memory_flag.set()
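For readers new to these primitives, the two hunks above lean on a deque with maxlen (a fixed-capacity replay memory that silently discards the oldest transition once full) and a threading.Event (the _init_memory_flag that lets the filling code announce when the memory holds enough samples). A minimal standalone sketch of that pattern, not repository code:

import threading
from collections import deque

memory = deque(maxlen=4)        # fixed-size buffer; oldest entries are evicted
init_done = threading.Event()   # starts unset, like _init_memory_flag

def fill():
    for t in range(6):
        memory.append(t)        # after 6 appends only the last 4 remain
    init_done.set()             # signal "memory has been initialized"

threading.Thread(target=fill).start()
init_done.wait()                # readers block here until fill() calls set()
print(list(memory))             # [2, 3, 4, 5]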
@@ -111,7 +115,7 @@ class ExpReplay(DataFlow, Callback):
         while True:
             batch_exp = [self._sample_one() for _ in range(self.batch_size)]
-            #import cv2
+            #import cv2   # for debug
             #def view_state(state, next_state):
                 #""" for debugging state representation"""
                 #r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
@@ -126,8 +130,7 @@ class ExpReplay(DataFlow, Callback):
                 #view_state(exp[0], exp[1])
             yield self._process_batch(batch_exp)
-            for _ in range(self.update_frequency):
-                self._populate_exp()
+            self._populate_job_queue.put(1)

     def _sample_one(self):
         """ return the transition tuple for
@@ -170,6 +173,16 @@ class ExpReplay(DataFlow, Callback):
     # Callback-related:

     def _before_train(self):
+        # spawn a separate thread to run policy, can speed up 1.3x
+        self._populate_job_queue = queue.Queue(maxsize=1)
+        def populate_job_func():
+            self._populate_job_queue.get()
+            with self.trainer.sess.as_default():
+                for _ in range(self.update_frequency):
+                    self._populate_exp()
+        self._populate_job_th = LoopThread(populate_job_func, False)
+        self._populate_job_th.start()
+
         self._init_memory()

     def _trigger_epoch(self):
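Taken together, the last two hunks are the whole optimization: get_data() no longer steps the environment itself; it yields a batch and drops a token into a one-slot queue, while a looping worker thread (LoopThread running populate_job_func under the trainer's session) performs the update_frequency environment steps. Below is a self-contained sketch of that producer/consumer pattern; it uses a plain threading.Thread in place of tensorpack's LoopThread, and the names populate_exp, UPDATE_FREQUENCY and NUM_BATCHES are illustrative stand-ins, not identifiers from the repository.

import threading
from six.moves import queue

UPDATE_FREQUENCY = 4                 # stand-in for self.update_frequency
NUM_BATCHES = 5                      # stand-in for however many batches the trainer pulls

job_queue = queue.Queue(maxsize=1)   # one-slot queue: the generator never queues
                                     # more than one job ahead of the worker

def populate_exp():
    """Stub for ExpReplay._populate_exp(): one env step appended to the replay memory."""
    pass

def populate_job_func():
    # Role of the function handed to LoopThread: wait for a job token, then
    # step the environment update_frequency times.
    while True:
        job_queue.get()
        for _ in range(UPDATE_FREQUENCY):
            populate_exp()
        job_queue.task_done()

threading.Thread(target=populate_job_func, daemon=True).start()

def get_data():
    # Shape of ExpReplay.get_data() after this commit: yield the batch first,
    # then ask the worker for more experience instead of collecting it inline.
    for _ in range(NUM_BATCHES):
        batch = object()             # placeholder for self._process_batch(batch_exp)
        yield batch
        job_queue.put(1)

for batch in get_data():
    pass                             # the training step would consume `batch` here
job_queue.join()                     # let the final populate job finish

The maxsize=1 queue doubles as backpressure: if the trainer outpaces the simulator, put(1) blocks until the worker has picked up the previous job, so the replay memory never lags more than one update round behind the training loop.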