Commit 785e01e2 authored by Yuxin Wu

speedup expreplay by 1.3x

parent f1fc7337
......@@ -8,9 +8,11 @@ from collections import deque, namedtuple
import threading
from tqdm import tqdm
import six
from six.moves import queue
from ..dataflow import DataFlow
from ..utils import *
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback
__all__ = ['ExpReplay']
......@@ -58,7 +60,7 @@ class ExpReplay(DataFlow, Callback):
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.mem = deque(maxlen=memory_size)
self.rng = get_rng(self)
self._init_memory_flag = threading.Event()
self._init_memory_flag = threading.Event() # tell if memory has been initialized
def _init_memory(self):
logger.info("Populating replay memory...")
......@@ -72,6 +74,8 @@ class ExpReplay(DataFlow, Callback):
with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
#from copy import deepcopy # for debug
#self.mem.append(deepcopy(self.mem[0]))
self._populate_exp()
pbar.update()
self._init_memory_flag.set()
......@@ -111,7 +115,7 @@ class ExpReplay(DataFlow, Callback):
while True:
batch_exp = [self._sample_one() for _ in range(self.batch_size)]
#import cv2
#import cv2 # for debug
#def view_state(state, next_state):
#""" for debugging state representation"""
#r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
......@@ -126,8 +130,7 @@ class ExpReplay(DataFlow, Callback):
#view_state(exp[0], exp[1])
yield self._process_batch(batch_exp)
for _ in range(self.update_frequency):
self._populate_exp()
self._populate_job_queue.put(1)
def _sample_one(self):
""" return the transition tuple for
......@@ -170,6 +173,16 @@ class ExpReplay(DataFlow, Callback):
# Callback-related:
def _before_train(self):
    """Start the background populate thread, then fill the replay memory.

    A one-slot job queue is created and a worker function is handed to a
    ``LoopThread``. Each time the worker runs it waits for one queue item
    and then calls ``_populate_exp`` ``update_frequency`` times with the
    trainer's session installed as the default session. Running the policy
    on a separate thread was reported to speed things up by about 1.3x.
    ``_init_memory`` is invoked last, on the calling thread.
    """
    # maxsize=1: at most one populate job can be pending at any time.
    self._populate_job_queue = queue.Queue(maxsize=1)

    def populate_job_func():
        # Block until a job token arrives; the token's value is not used.
        self._populate_job_queue.get()
        with self.trainer.sess.as_default():
            for _step in range(self.update_frequency):
                self._populate_exp()

    # NOTE(review): the second positional argument (False) is presumably
    # LoopThread's "pausable" flag, and LoopThread presumably calls the
    # function repeatedly -- confirm against utils.concurrency.
    worker = LoopThread(populate_job_func, False)
    self._populate_job_th = worker
    worker.start()

    self._init_memory()
def _trigger_epoch(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment