Commit 3321123f authored by Yuxin Wu's avatar Yuxin Wu

[DQN] remove the simulator thread.

parent 27ba042b
...@@ -6,12 +6,11 @@ import copy ...@@ -6,12 +6,11 @@ import copy
import numpy as np import numpy as np
import threading import threading
from collections import deque, namedtuple from collections import deque, namedtuple
from six.moves import queue, range from six.moves import range
from tensorpack.callbacks.base import Callback from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm from tensorpack.utils.utils import get_rng, get_tqdm
...@@ -165,24 +164,11 @@ class ExpReplay(DataFlow, Callback): ...@@ -165,24 +164,11 @@ class ExpReplay(DataFlow, Callback):
self.rng = get_rng(self) self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized self._init_memory_flag = threading.Event() # tell if memory has been initialized
# a queue to receive notifications to populate memory
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape, history_len) self.mem = ReplayMemory(memory_size, state_shape, history_len)
self._current_ob = self.player.reset() self._current_ob = self.player.reset()
self._player_scores = StatCounter() self._player_scores = StatCounter()
self._current_game_score = StatCounter() self._current_game_score = StatCounter()
def get_simulator_thread(self):
# spawn a separate thread to run policy
def populate_job_func():
self._populate_job_queue.get()
for _ in range(self.update_frequency):
self._populate_exp()
th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
th.name = "SimulatorThread"
return th
def _init_memory(self): def _init_memory(self):
logger.info("Populating replay memory with epsilon={} ...".format(self.exploration)) logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))
...@@ -263,13 +249,15 @@ class ExpReplay(DataFlow, Callback): ...@@ -263,13 +249,15 @@ class ExpReplay(DataFlow, Callback):
while True: while True:
idx = self.rng.randint( idx = self.rng.randint(
self._populate_job_queue.maxsize * self.update_frequency, 0, len(self.mem) - self.history_len - 1,
len(self.mem) - self.history_len - 1,
size=self.batch_size) size=self.batch_size)
batch_exp = [self.mem.sample(i) for i in idx] batch_exp = [self.mem.sample(i) for i in idx]
yield self._process_batch(batch_exp) yield self._process_batch(batch_exp)
self._populate_job_queue.put(1)
# execute 4 new actions into memory, after each batch update
for _ in range(self.update_frequency):
self._populate_exp()
# Callback methods: # Callback methods:
def _setup_graph(self): def _setup_graph(self):
...@@ -277,8 +265,6 @@ class ExpReplay(DataFlow, Callback): ...@@ -277,8 +265,6 @@ class ExpReplay(DataFlow, Callback):
def _before_train(self): def _before_train(self):
self._init_memory() self._init_memory()
self._simulator_th = self.get_simulator_thread()
self._simulator_th.start()
def _trigger(self): def _trigger(self):
v = self._player_scores v = self._player_scores
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment