Commit b6df5567 authored by Yuxin Wu

move expreplay out of RL.

parent 63004976
@@ -26,6 +26,7 @@ from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
+from expreplay import ExpReplay
 
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
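The example now gets ExpReplay from a local expreplay module sitting next to the script, rather than receiving it through the wildcard import from tensorpack.RL; this is the visible half of the commit's point, moving the replay-buffer implementation out of the RL subpackage.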
@@ -160,7 +161,7 @@ def get_config():
     logger.auto_set_dir()
     M = Model()
-    dataset_train = ExpReplay(
+    expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
@@ -174,21 +175,22 @@ def get_config():
         history_len=FRAME_HISTORY)
 
     return TrainConfig(
-        dataflow=dataset_train,
+        dataflow=expreplay,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
             RunOp(lambda: M.update_target_param()),
-            dataset_train,
+            expreplay,
+            StartProcOrThread(expreplay.get_simulator_thread()),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
             # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
-            # HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
+            # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
         ],
         # save memory for multi-thread evaluator
         session_config=get_default_sess_config(0.6),
         model=M,
         steps_per_epoch=STEP_PER_EPOCH,
         # run the simulator on a separate GPU if available
         predict_tower=[1] if get_nr_gpu() > 1 else [0],
     )
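Two changes in the config above: the replay buffer is renamed from dataset_train to expreplay (it is a replay buffer, not a fixed dataset), and its simulator thread is now created explicitly via get_simulator_thread() and handed to StartProcOrThread, which presumably starts it once training begins. Note that the same object appears twice: as dataflow= it is the DataFlow that yields training batches, and inside callbacks= it is a Callback that receives trainer hooks. A minimal sketch of this dual-role pattern, with hypothetical names (ReplayFeed is not tensorpack's API):

    import queue
    import threading

    import numpy as np

    class ReplayFeed:
        """Sketch: one object is both the batch source and a trainer hook."""

        def __init__(self, batch_size=4):
            self.batch_size = batch_size
            self.mem = []
            self._notify = queue.Queue(maxsize=5)    # bounded: throttles the producer

        # "DataFlow" side: the training loop iterates over this generator.
        def get_data(self):
            while True:
                idx = np.random.randint(len(self.mem), size=self.batch_size)
                self._notify.put(1)                  # request more experience
                yield [self.mem[i] for i in idx]

        # "Callback" side: the trainer calls this once before training starts.
        def before_train(self):
            self.mem = [np.zeros(2) for _ in range(100)]   # stand-in for memory warm-up

        # The trainer starts this thread via a StartProcOrThread-style callback.
        def get_simulator_thread(self):
            def run():
                while True:
                    self._notify.get()               # block until a batch was consumed
                    self.mem.append(np.random.rand(2))
            return threading.Thread(target=run, daemon=True, name="SimulatorThread")

The remaining hunks are from the expreplay module itself, which has moved out of the tensorpack.RL package.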
@@ -9,10 +9,10 @@ import threading
 import six
 from six.moves import queue
 
-from ..dataflow import DataFlow
-from ..utils import logger, get_tqdm, get_rng
-from ..utils.concurrency import LoopThread
-from ..callbacks.base import Callback
+from tensorpack.dataflow import DataFlow
+from tensorpack.utils import logger, get_tqdm, get_rng
+from tensorpack.utils.concurrency import LoopThread
+from tensorpack.callbacks.base import Callback
 
 __all__ = ['ExpReplay']
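Since the module no longer lives inside the tensorpack package tree, the relative imports (from ..dataflow ...) would not resolve; it now imports from the installed tensorpack package by absolute path, which also lets the file be copied next to an example script and used standalone.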
@@ -66,17 +66,28 @@ class ExpReplay(DataFlow, Callback):
         self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
-        self._predictor_io_names = predictor_io_names
+
+        # TODO just use a semaphore?
+        # a queue to receive notifications to populate memory
+        self._populate_job_queue = queue.Queue(maxsize=5)
+
+    def get_simulator_thread(self):
+        # spawn a separate thread to run policy, can speed up 1.3x
+        def populate_job_func():
+            self._populate_job_queue.get()
+            with self.trainer.sess.as_default():
+                for _ in range(self.update_frequency):
+                    self._populate_exp()
+        th = LoopThread(populate_job_func, pausable=False)
+        th.name = "SimulatorThread"
+        return th
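The bounded queue is the "semaphore" the TODO above alludes to: get_data puts one token after each batch it yields, and the simulator thread blocks in get() until a token arrives, then plays update_frequency steps. With maxsize=5, the simulator can run at most five populate jobs ahead of the consumer. A toy version of the token pattern (standalone sketch, not the tensorpack classes):

    import queue
    import threading
    import time

    tokens = queue.Queue(maxsize=5)          # at most 5 outstanding populate requests

    def simulator():
        while True:
            tokens.get()                     # block until the consumer asks for more
            time.sleep(0.01)                 # stand-in for playing update_frequency steps

    threading.Thread(target=simulator, daemon=True).start()

    for _ in range(10):                      # stand-in for the training loop
        tokens.put(1)                        # one token per consumed batch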
     def _init_memory(self):
-        logger.info("Populating replay memory...")
+        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))
+
         # fill some for the history
         old_exploration = self.exploration
         self.exploration = 1
         for k in range(self.history_len):
             self._populate_exp()
         self.exploration = old_exploration
 
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
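During initialization, exploration is temporarily forced to 1 so the first history_len transitions come from uniformly random actions; the Q-network is still untrained at this point, so greedy actions would add no useful signal while costing a forward pass per step. The epsilon-greedy rule being toggled looks like this in isolation (a sketch, not the exact _populate_exp code):

    import numpy as np

    def epsilon_greedy(q_values, epsilon, rng):
        # With probability epsilon take a uniformly random action,
        # otherwise take the greedy (argmax-Q) action.
        if rng.uniform() < epsilon:
            return int(rng.randint(len(q_values)))
        return int(np.argmax(q_values))

    rng = np.random.RandomState(0)
    epsilon_greedy(np.array([0.1, 0.9, 0.3]), epsilon=1.0, rng=rng)  # always random during warm-up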
@@ -95,7 +106,7 @@ class ExpReplay(DataFlow, Callback):
             act = self.rng.choice(range(self.num_actions))
         else:
             # build a history state
-            # XXX assume a state can be represented by one tensor
+            # assume a state can be represented by one tensor
             ss = [old_s]
 
             isOver = False
@@ -104,12 +115,13 @@ class ExpReplay(DataFlow, Callback):
                 if hist_exp.isOver:
                     isOver = True
                 if isOver:
+                    # fill the beginning of an episode with zeros
                     ss.append(np.zeros_like(ss[0]))
                 else:
                     ss.append(hist_exp.state)
             ss.reverse()
             ss = np.concatenate(ss, axis=2)
-            # XXX assume batched network
+            # assume batched network
             q_values = self.predictor([[ss]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
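To pick a greedy action, the last history_len frames are concatenated along the channel axis into a single state tensor, and any frame lying on the far side of an episode boundary (isOver) is replaced with zeros, since the previous episode carries no information about the current one. The same stacking rule as a standalone sketch (hypothetical helper, simplified from the loop above):

    import numpy as np

    def stack_history(frames, overs, history_len):
        """frames: (H, W, 1) arrays, oldest..newest; overs: matching isOver flags.
        Returns an (H, W, history_len) state, zeroing frames behind an episode end."""
        ss = [frames[-1]]
        crossed = False
        for k in range(2, history_len + 1):
            # walking backwards in time; past an episode end, everything is zeroed
            if overs[-k]:
                crossed = True
            ss.append(np.zeros_like(ss[0]) if crossed else frames[-k])
        ss.reverse()                       # oldest frame first
        return np.concatenate(ss, axis=2)

    state = stack_history([np.ones((84, 84, 1)) * i for i in range(4)],
                          [False, True, False, False], history_len=4)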
@@ -118,8 +130,9 @@ class ExpReplay(DataFlow, Callback):
         self.mem.append(Experience(old_s, act, reward, isOver))
 
     def get_data(self):
+        # wait for memory to be initialized
         self._init_memory_flag.wait()
-        # new s is considered useless if isOver==True
 
         while True:
             batch_exp = [self._sample_one() for _ in range(self.batch_size)]
@@ -140,6 +153,7 @@ class ExpReplay(DataFlow, Callback):
             yield self._process_batch(batch_exp)
             self._populate_job_queue.put(1)
 
+    # new state is considered useless if isOver==True
     def _sample_one(self):
         """ return the transition tuple for
             [idx, idx+history_len) -> [idx+1, idx+1+history_len)
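The two windows in the docstring overlap in all but one frame: the sampled transition needs frames [idx, idx+history_len) for the current state and [idx+1, idx+1+history_len) for the next one. This is why the memory stores one frame per step (one Experience) rather than pre-stacked states, a roughly history_len-fold memory saving. Illustratively:

    history_len = 4
    idx = 10
    state_window      = list(range(idx, idx + history_len))          # frames [10, 11, 12, 13]
    next_state_window = list(range(idx + 1, idx + 1 + history_len))  # frames [11, 12, 13, 14]
    shared = set(state_window) & set(next_state_window)              # {11, 12, 13}: stored once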
@@ -173,29 +187,17 @@ class ExpReplay(DataFlow, Callback):
         return (state, next_state, reward, action, isOver)
 
     def _process_batch(self, batch_exp):
-        state = np.array([e[0] for e in batch_exp])
-        next_state = np.array([e[1] for e in batch_exp])
-        reward = np.array([e[2] for e in batch_exp])
-        action = np.array([e[3] for e in batch_exp], dtype='int8')
-        isOver = np.array([e[4] for e in batch_exp], dtype='bool')
+        state = np.asarray([e[0] for e in batch_exp])
+        next_state = np.asarray([e[1] for e in batch_exp])
+        reward = np.asarray([e[2] for e in batch_exp])
+        action = np.asarray([e[3] for e in batch_exp], dtype='int8')
+        isOver = np.asarray([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]
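The np.array to np.asarray switch is behavior-preserving here, since the inputs are fresh Python lists that both functions must copy into a new array; the two differ only when the input is already an ndarray, where asarray skips the copy:

    import numpy as np

    a = np.zeros(4, dtype='float32')
    assert np.asarray(a) is a        # already an ndarray: returned as-is, no copy
    assert np.array(a) is not a      # np.array copies by default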
     def _setup_graph(self):
-        self.predictor = self.trainer.get_predict_func(*self._predictor_io_names)
+        self.predictor = self.trainer.get_predict_func(*self.predictor_io_names)
 
     # Callback-related:
     def _before_train(self):
-        # spawn a separate thread to run policy, can speed up 1.3x
-        self._populate_job_queue = queue.Queue(maxsize=1)
-
-        def populate_job_func():
-            self._populate_job_queue.get()
-            with self.trainer.sess.as_default():
-                for _ in range(self.update_frequency):
-                    self._populate_exp()
-        self._populate_job_th = LoopThread(populate_job_func, False)
-        self._populate_job_th.start()
         self._init_memory()
 
     def _trigger_epoch(self):
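With the thread construction moved up into get_simulator_thread(), _before_train shrinks to just filling the initial memory; starting the thread becomes the trainer's job through the StartProcOrThread callback registered in the first file. The notification queue also moved from here into __init__ and grew from maxsize=1 to maxsize=5, so the simulator can buffer a few populate requests instead of strictly alternating with the consumer.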