Commit 785e01e2 authored by Yuxin Wu's avatar Yuxin Wu

speedup expreplay by 1.3x

parent f1fc7337
...@@ -8,9 +8,11 @@ from collections import deque, namedtuple ...@@ -8,9 +8,11 @@ from collections import deque, namedtuple
import threading import threading
from tqdm import tqdm from tqdm import tqdm
import six import six
from six.moves import queue
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import * from ..utils import *
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
__all__ = ['ExpReplay'] __all__ = ['ExpReplay']
...@@ -58,7 +60,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -58,7 +60,7 @@ class ExpReplay(DataFlow, Callback):
logger.info("Number of Legal actions: {}".format(self.num_actions)) logger.info("Number of Legal actions: {}".format(self.num_actions))
self.mem = deque(maxlen=memory_size) self.mem = deque(maxlen=memory_size)
self.rng = get_rng(self) self.rng = get_rng(self)
self._init_memory_flag = threading.Event() self._init_memory_flag = threading.Event() # tell if memory has been initialized
def _init_memory(self): def _init_memory(self):
logger.info("Populating replay memory...") logger.info("Populating replay memory...")
...@@ -72,6 +74,8 @@ class ExpReplay(DataFlow, Callback): ...@@ -72,6 +74,8 @@ class ExpReplay(DataFlow, Callback):
with tqdm(total=self.init_memory_size) as pbar: with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size: while len(self.mem) < self.init_memory_size:
#from copy import deepcopy # for debug
#self.mem.append(deepcopy(self.mem[0]))
self._populate_exp() self._populate_exp()
pbar.update() pbar.update()
self._init_memory_flag.set() self._init_memory_flag.set()
...@@ -111,7 +115,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -111,7 +115,7 @@ class ExpReplay(DataFlow, Callback):
while True: while True:
batch_exp = [self._sample_one() for _ in range(self.batch_size)] batch_exp = [self._sample_one() for _ in range(self.batch_size)]
#import cv2 #import cv2 # for debug
#def view_state(state, next_state): #def view_state(state, next_state):
#""" for debugging state representation""" #""" for debugging state representation"""
#r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1) #r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
...@@ -126,8 +130,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -126,8 +130,7 @@ class ExpReplay(DataFlow, Callback):
#view_state(exp[0], exp[1]) #view_state(exp[0], exp[1])
yield self._process_batch(batch_exp) yield self._process_batch(batch_exp)
for _ in range(self.update_frequency): self._populate_job_queue.put(1)
self._populate_exp()
def _sample_one(self): def _sample_one(self):
""" return the transition tuple for """ return the transition tuple for
...@@ -170,6 +173,16 @@ class ExpReplay(DataFlow, Callback): ...@@ -170,6 +173,16 @@ class ExpReplay(DataFlow, Callback):
# Callback-related: # Callback-related:
def _before_train(self): def _before_train(self):
# spawn a separate thread to run policy, can speed up 1.3x
self._populate_job_queue = queue.Queue(maxsize=1)
def populate_job_func():
self._populate_job_queue.get()
with self.trainer.sess.as_default():
for _ in range(self.update_frequency):
self._populate_exp()
self._populate_job_th = LoopThread(populate_job_func, False)
self._populate_job_th.start()
self._init_memory() self._init_memory()
def _trigger_epoch(self): def _trigger_epoch(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment