Commit 3321123f authored by Yuxin Wu

[DQN] remove the simulator thread.

parent 27ba042b
......@@ -6,12 +6,11 @@ import copy
import numpy as np
import threading
from collections import deque, namedtuple
from six.moves import queue, range
from six.moves import range
from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm
......@@ -165,24 +164,11 @@ class ExpReplay(DataFlow, Callback):
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
# a queue to receive notifications to populate memory
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape, history_len)
self._current_ob = self.player.reset()
self._player_scores = StatCounter()
self._current_game_score = StatCounter()
# Build the thread that populates the replay memory alongside training.
# Returns the thread object without starting it; the caller is expected to
# call .start() on it (as _before_train does).
def get_simulator_thread(self):
# spawn a separate thread to run policy
# One job: block until a notification arrives on the job queue, then call
# _populate_exp() update_frequency times.
def populate_job_func():
self._populate_job_queue.get()
for _ in range(self.update_frequency):
self._populate_exp()
# NOTE(review): presumably LoopThread invokes populate_job_func repeatedly and
# ShareSessionThread lets it run under the trainer's session -- confirm
# against tensorpack.utils.concurrency.
th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
# Give the thread a recognizable name.
th.name = "SimulatorThread"
return th
def _init_memory(self):
logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))
......@@ -263,13 +249,15 @@ class ExpReplay(DataFlow, Callback):
while True:
idx = self.rng.randint(
self._populate_job_queue.maxsize * self.update_frequency,
len(self.mem) - self.history_len - 1,
0, len(self.mem) - self.history_len - 1,
size=self.batch_size)
batch_exp = [self.mem.sample(i) for i in idx]
yield self._process_batch(batch_exp)
self._populate_job_queue.put(1)
# execute `update_frequency` new actions into memory, after each batch update
for _ in range(self.update_frequency):
self._populate_exp()
# Callback methods:
def _setup_graph(self):
......@@ -277,8 +265,6 @@ class ExpReplay(DataFlow, Callback):
# Callback method (presumably invoked once before training starts -- confirm
# against the Callback base class): fill the replay memory, then create and
# start the simulator thread.
def _before_train(self):
# Populate the replay memory first.
self._init_memory()
# Keep a reference to the simulator thread on the instance, then start it.
self._simulator_th = self.get_simulator_thread()
self._simulator_th.start()
def _trigger(self):
v = self._player_scores
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment