Commit 8ce73803 authored by Yuxin Wu

[DQN] DQN with parallel simulators

parent c51ce295
......@@ -27,6 +27,7 @@ MEMORY_SIZE = 1e6
INIT_MEMORY_SIZE = MEMORY_SIZE // 20
STEPS_PER_EPOCH = 100000 // UPDATE_FREQ # each epoch is 100k played frames
EVAL_EPISODE = 50
NUM_PARALLEL_PLAYERS = 3
USE_GYM = False
ENV_NAME = None
......@@ -101,9 +102,11 @@ class Model(DQNModel):
def get_config(model):
global args
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
get_player=lambda: get_player(train=True),
num_parallel_players=NUM_PARALLEL_PLAYERS,
state_shape=model.state_shape,
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
......@@ -134,7 +137,7 @@ def get_config(model):
interp='linear'),
PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
every_k_epochs=10),
every_k_epochs=5 if 'pong' in args.env.lower() else 10), # eval more frequently for easy games
HumanHyperParamSetter('learning_rate'),
],
steps_per_epoch=STEPS_PER_EPOCH,
......
......@@ -20,9 +20,10 @@ Claimed performance in the paper can be reproduced, on several games I've tested
![DQN](curve-breakout.png)
On one GTX 1080Ti, the ALE version took ~3 hours of training to reach 21 (maximum) score on
Pong, ~15 hours of training to reach 400 score on Breakout.
It runs at 50 batches (~3.2k trained frames, 200 seen frames, 800 game frames) per second on GTX 1080Ti.
On one GTX 1080Ti, the ALE version took __~2 hours__ of training to reach a score of 21 (the maximum) on
Pong, and __~10 hours__ to reach a score of 400 on Breakout.
It runs at 80 batches (~5.1k trained frames, 320 seen frames, 1.3k game frames) per second on GTX 1080Ti.
This is likely the fastest open source TF implementation of DQN.
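As a rough sanity check on these throughput numbers, the sketch below works out the frame accounting. It assumes `BATCH_SIZE = 64` and an ALE action repeat (frame skip) of 4, neither of which is shown in this diff; `UPDATE_FREQ = 4` matches the `update_freq=4` comment elsewhere in the commit.

```python
# Rough frame accounting behind the "80 batches/s" figure.
# Assumptions: BATCH_SIZE = 64, UPDATE_FREQ = 4, ACTION_REPEAT = 4 (frame skip).
batches_per_sec = 80
BATCH_SIZE, UPDATE_FREQ, ACTION_REPEAT = 64, 4, 4

trained_frames = batches_per_sec * BATCH_SIZE   # frames consumed by SGD: 5120 ~ 5.1k
seen_frames = batches_per_sec * UPDATE_FREQ     # new environment steps: 320
game_frames = seen_frames * ACTION_REPEAT       # raw emulator frames: 1280 ~ 1.3k
print(trained_frames, seen_frames, game_frames)
```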
## How to use
......
......@@ -3,11 +3,13 @@
# Author: Yuxin Wu
import copy
import itertools
import numpy as np
import threading
from collections import namedtuple
from six.moves import range
from six.moves import queue, range
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger
......@@ -46,6 +48,8 @@ class ReplayMemory(object):
self._curr_size = 0
self._curr_pos = 0
self.writer_lock = threading.Lock() # a lock to guard writing to the memory
def append(self, exp):
"""
Args:
......@@ -132,7 +136,7 @@ class EnvRunner(object):
self._current_episode = []
self._current_ob = player.reset()
self._current_game_score = StatCounter() # store per-step reward
self._player_scores = StatCounter() # store per-game total score
self.total_scores = [] # store per-game total score
self.rng = get_rng(self)
......@@ -168,12 +172,13 @@ class EnvRunner(object):
self.player.reset()
if flush_experience:
self._player_scores.feed(self._current_game_score.sum)
self.total_scores.append(self._current_game_score.sum)
self._current_game_score.reset()
# TODO lock here if having multiple runner
for exp in self._current_episode:
self.memory.append(exp)
# Ensure that the whole episode of experience is continuous in the replay buffer
with self.memory.writer_lock:
for exp in self._current_episode:
self.memory.append(exp)
self._current_episode.clear()
def recent_state(self):
......@@ -192,6 +197,64 @@ class EnvRunner(object):
return states
class EnvRunnerManager(object):
"""
A class which manages a list of :class:`EnvRunner`.
Its job is to execute them, possibly in parallel, and to aggregate their results.
"""
def __init__(self, env_runners, maximum_staleness):
"""
Args:
env_runners (list[EnvRunner]):
maximum_staleness (int): when >1 environments run in parallel,
the actual stepping of an environment may happen up to this many steps
after the corresponding call to `EnvRunnerManager.step()`, in order to achieve better throughput.
"""
assert len(env_runners) > 0
self._runners = env_runners
if len(self._runners) > 1:
# Only use threads when there is more than one runner.
self._populate_job_queue = queue.Queue(maxsize=maximum_staleness)
self._threads = [self._create_simulator_thread(i) for i in range(len(self._runners))]
for t in self._threads:
t.start()
def _create_simulator_thread(self, idx):
# spawn a separate thread to run policy
def populate_job_func():
exp = self._populate_job_queue.get()
self._runners[idx].step(exp)
th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
th.name = "SimulatorThread-{}".format(idx)
return th
def step(self, exploration):
"""
Execute one step in any of the runners.
"""
if len(self._runners) > 1:
self._populate_job_queue.put(exploration)
else:
self._runners[0].step(exploration)
def reset_stats(self):
"""
Returns:
mean, max: mean and max of the per-game total scores collected since the last call, to be added to the monitor backend
"""
scores = list(itertools.chain.from_iterable([v.total_scores for v in self._runners]))
for v in self._runners:
v.total_scores.clear()
try:
return np.mean(scores), np.max(scores)
except Exception:
logger.exception("Cannot compute total scores in EnvRunner.")
return None, None
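# --------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the bounded-queue worker
# pattern that EnvRunnerManager relies on, written with only the standard
# library. `step()` merely enqueues an exploration value; each worker thread
# pops one request and advances its own environment, so the environments can
# lag the caller by at most `maximum_staleness` pending steps.
import queue
import threading

class ToyRunnerManager:
    def __init__(self, runners, maximum_staleness):
        self._jobs = queue.Queue(maxsize=maximum_staleness)
        self._threads = [
            threading.Thread(target=self._work, args=(r,), daemon=True)
            for r in runners]
        for t in self._threads:
            t.start()

    def _work(self, runner):
        while True:
            exploration = self._jobs.get()   # block until a step is requested
            runner.step(exploration)         # each worker owns one environment

    def step(self, exploration):
        # Blocks only when `maximum_staleness` requests are already pending.
        self._jobs.put(exploration)
# --------------------------------------------------------------------------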
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
......@@ -201,13 +264,22 @@ class ExpReplay(DataFlow, Callback):
This implementation provides the interface as a :class:`DataFlow`.
This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
This implementation assumes that state is
batch-able, and the network takes batched inputs.
It does the following:
* Spawns `num_parallel_players` environment threads, each running an instance
of the environment with an epsilon-greedy policy.
* All environment instances write their experiences to a shared replay
memory buffer.
* Produces batched samples by sampling the replay buffer. After producing
each batch, it executes the environment instances for a total of
`update_frequency` steps.
This implementation assumes that state is batch-able, and the network takes batched inputs.
"""
def __init__(self,
predictor_io_names,
player,
get_player,
num_parallel_players,
state_shape,
batch_size,
memory_size, init_memory_size,
......@@ -217,7 +289,11 @@ class ExpReplay(DataFlow, Callback):
Args:
predictor_io_names (tuple of list of str): input/output names to
predict Q value from state.
player (gym.Env): the player.
get_player (-> gym.Env): a callable which returns a player.
num_parallel_players (int): number of players to run in parallel.
Standard DQN uses 1.
Parallelism increases speed, but will affect the distribution of
experiences in the replay buffer.
state_shape (tuple):
batch_size (int):
memory_size (int):
......@@ -235,8 +311,6 @@ class ExpReplay(DataFlow, Callback):
if k != 'self':
setattr(self, k, v)
self.exploration = 1.0 # default initial exploration
self.num_actions = player.action_space.n
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
......@@ -248,7 +322,7 @@ class ExpReplay(DataFlow, Callback):
with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
self.env_runner.step(self.exploration)
self.runner.step(self.exploration)
pbar.update()
self._init_memory_flag.set()
......@@ -257,7 +331,7 @@ class ExpReplay(DataFlow, Callback):
from copy import deepcopy
with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < 5:
self.env_runner.step(self.exploration)
self.runner.step(self.exploration)
pbar.update()
while len(self.mem) < self.init_memory_size:
self.mem.append(deepcopy(self.mem._hist[0]))
......@@ -300,43 +374,24 @@ class ExpReplay(DataFlow, Callback):
yield self._process_batch(batch_exp)
# execute 4 new actions into memory, after each batch update
# execute update_freq=4 new actions into memory, after each batch update
for _ in range(self.update_frequency):
self.env_runner.step(self.exploration)
self.runner.step(self.exploration)
# Callback methods:
def _setup_graph(self):
predictor = self.trainer.get_predictor(*self.predictor_io_names)
self.env_runner = EnvRunner(self.player, predictor, self.mem, self.history_len)
self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
def _before_train(self):
env_runners = [
EnvRunner(self.get_player(), self.predictor, self.mem, self.history_len)
for k in range(self.num_parallel_players)
]
self.runner = EnvRunnerManager(env_runners, self.update_frequency * 2)
self._init_memory()
def _trigger(self):
v = self.env_runner._player_scores
try:
mean, max = v.average, v.max
mean, max = self.runner.reset_stats()
if mean is not None:
self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
self.trainer.monitors.put_scalar('expreplay/max_score', max)
except Exception:
logger.exception("Cannot log training scores.")
v.reset()
if __name__ == '__main__':
from .atari import AtariPlayer
import sys
def predictor(x):
return np.array([1, 1, 1, 1])
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
E = ExpReplay(predictor,
player=player,
num_actions=player.get_action_space().num_actions(),
populate_size=1001,
history_len=4)
E._init_memory()
for _ in E.get_data():
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
......@@ -47,6 +47,12 @@ class StatCounter(object):
assert len(self._values)
return min(self._values)
def samples(self):
"""
Returns all samples.
"""
return self._values
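# Illustrative usage sketch (not part of this commit). It assumes the usual
# tensorpack.utils.stats import path, and calls samples() as a plain method
# since no @property decorator is visible in this diff.
from tensorpack.utils.stats import StatCounter

c = StatCounter()
for v in [1.0, 3.0, 2.0]:
    c.feed(v)
print(c.average, c.max)   # 2.0 3.0
print(c.samples())        # the raw fed values: [1.0, 3.0, 2.0]
c.reset()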
class RatioCounter(object):
""" A counter to count ratio of something. """
......