Commit 8ce73803 authored by Yuxin Wu

[DQN] DQN with parallel simulators

parent c51ce295
@@ -27,6 +27,7 @@ MEMORY_SIZE = 1e6
 INIT_MEMORY_SIZE = MEMORY_SIZE // 20
 STEPS_PER_EPOCH = 100000 // UPDATE_FREQ  # each epoch is 100k played frames
 EVAL_EPISODE = 50
+NUM_PARALLEL_PLAYERS = 3
 USE_GYM = False
 ENV_NAME = None
@@ -101,9 +102,11 @@ class Model(DQNModel):
 def get_config(model):
+    global args
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
-        player=get_player(train=True),
+        get_player=lambda: get_player(train=True),
+        num_parallel_players=NUM_PARALLEL_PLAYERS,
         state_shape=model.state_shape,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
@@ -134,7 +137,7 @@ def get_config(model):
                 interp='linear'),
             PeriodicTrigger(Evaluator(
                 EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
-                every_k_epochs=10),
+                every_k_epochs=5 if 'pong' in args.env.lower() else 10),  # eval more frequently for easy games
             HumanHyperParamSetter('learning_rate'),
         ],
         steps_per_epoch=STEPS_PER_EPOCH,
...
@@ -20,9 +20,10 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 ![DQN](curve-breakout.png)

-On one GTX 1080Ti, the ALE version took ~3 hours of training to reach 21 (maximum) score on
-Pong, ~15 hours of training to reach 400 score on Breakout.
-It runs at 50 batches (~3.2k trained frames, 200 seen frames, 800 game frames) per second on GTX 1080Ti.
+On one GTX 1080Ti, the ALE version took __~2 hours__ of training to reach 21 (maximum) score on
+Pong, __~10 hours__ of training to reach 400 score on Breakout.
+It runs at 80 batches (~5.1k trained frames, 320 seen frames, 1.3k game frames) per second on GTX 1080Ti.
+This is likely the fastest open source TF implementation of DQN.

 ## How to use
...
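Editor's note: the frame-rate figures in the README hunk above follow directly from the batch rate and the example's hyperparameters. The sketch below is only a back-of-the-envelope check and assumes BATCH_SIZE = 64, UPDATE_FREQ = 4 and a frame skip (action repeat) of 4, which are the defaults in this DQN example; adjust if your config differs.

```python
# Rough arithmetic behind "80 batches (~5.1k trained frames, 320 seen frames,
# 1.3k game frames) per second". Hyperparameter values below are assumptions,
# not read from this commit.
BATCH_SIZE = 64      # frames consumed by SGD per batch
UPDATE_FREQ = 4      # new environment steps taken per training step
FRAME_SKIP = 4       # emulator frames per environment step (action repeat)

batches_per_sec = 80
trained_frames = batches_per_sec * BATCH_SIZE   # 5120 -> "~5.1k trained frames"
seen_frames = batches_per_sec * UPDATE_FREQ     # 320  -> "320 seen frames"
game_frames = seen_frames * FRAME_SKIP          # 1280 -> "~1.3k game frames"
print(trained_frames, seen_frames, game_frames)
```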
@@ -3,11 +3,13 @@
 # Author: Yuxin Wu

 import copy
+import itertools
 import numpy as np
 import threading
 from collections import namedtuple
-from six.moves import range
+from six.moves import queue, range

+from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback
 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger
@@ -46,6 +48,8 @@ class ReplayMemory(object):
         self._curr_size = 0
         self._curr_pos = 0

+        self.writer_lock = threading.Lock()  # a lock to guard writing to the memory
+
     def append(self, exp):
         """
         Args:
@@ -132,7 +136,7 @@ class EnvRunner(object):
         self._current_episode = []
         self._current_ob = player.reset()
         self._current_game_score = StatCounter()  # store per-step reward
-        self._player_scores = StatCounter()  # store per-game total score
+        self.total_scores = []  # store per-game total score
         self.rng = get_rng(self)
@@ -168,12 +172,13 @@ class EnvRunner(object):
             self.player.reset()
             if flush_experience:
-                self._player_scores.feed(self._current_game_score.sum)
+                self.total_scores.append(self._current_game_score.sum)
                 self._current_game_score.reset()

-                # TODO lock here if having multiple runner
-                for exp in self._current_episode:
-                    self.memory.append(exp)
+                # Ensure that the whole episode of experience is continuous in the replay buffer
+                with self.memory.writer_lock:
+                    for exp in self._current_episode:
+                        self.memory.append(exp)
                 self._current_episode.clear()

     def recent_state(self):
@@ -192,6 +197,64 @@ class EnvRunner(object):
         return states

+
+class EnvRunnerManager(object):
+    """
+    A class which manages a list of :class:`EnvRunner`.
+    Its job is to execute them possibly in parallel and aggregate their results.
+    """
+    def __init__(self, env_runners, maximum_staleness):
+        """
+        Args:
+            env_runners (list[EnvRunner]):
+            maximum_staleness (int): when >1 environments run in parallel,
+                the actual stepping of an environment may happen several steps
+                after calls to `EnvRunnerManager.step()`, in order to achieve better throughput.
+        """
+        assert len(env_runners) > 0
+        self._runners = env_runners
+
+        if len(self._runners) > 1:
+            # Only use threads when having >1 runners.
+            self._populate_job_queue = queue.Queue(maxsize=maximum_staleness)
+            self._threads = [self._create_simulator_thread(i) for i in range(len(self._runners))]
+            for t in self._threads:
+                t.start()
+
+    def _create_simulator_thread(self, idx):
+        # spawn a separate thread to run policy
+        def populate_job_func():
+            exp = self._populate_job_queue.get()
+            self._runners[idx].step(exp)
+        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
+        th.name = "SimulatorThread-{}".format(idx)
+        return th
+
+    def step(self, exploration):
+        """
+        Execute one step in any of the runners.
+        """
+        if len(self._runners) > 1:
+            self._populate_job_queue.put(exploration)
+        else:
+            self._runners[0].step(exploration)
+
+    def reset_stats(self):
+        """
+        Returns:
+            mean, max: two stats of the runners, to be added to backend
+        """
+        scores = list(itertools.chain.from_iterable([v.total_scores for v in self._runners]))
+        for v in self._runners:
+            v.total_scores.clear()
+        try:
+            return np.mean(scores), np.max(scores)
+        except Exception:
+            logger.exception("Cannot compute total scores in EnvRunner.")
+            return None, None
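Editor's note: for readers unfamiliar with tensorpack's `LoopThread`/`ShareSessionThread`, the standalone sketch below (names and structure are illustrative, not part of this commit) shows the queue pattern `EnvRunnerManager` relies on: `step()` only enqueues the current exploration value, worker threads drain the queue and step their own environment, and the bounded queue size is what limits the "staleness" between enqueueing and the actual environment step.

```python
import threading
from six.moves import queue


class ToyRunnerManager(object):
    """Simplified stand-in for EnvRunnerManager: plain daemon threads
    instead of LoopThread/ShareSessionThread, no TF session sharing."""

    def __init__(self, runner_steps, maximum_staleness):
        # runner_steps: list of callables, each playing the role of EnvRunner.step
        self._queue = queue.Queue(maxsize=maximum_staleness)
        self._threads = [
            threading.Thread(target=self._loop, args=(step,), daemon=True)
            for step in runner_steps
        ]
        for t in self._threads:
            t.start()

    def _loop(self, runner_step):
        while True:
            exploration = self._queue.get()   # block until step() provides work
            runner_step(exploration)

    def step(self, exploration):
        # Blocks once `maximum_staleness` requests are pending, which bounds
        # how outdated the exploration value can be when it is finally consumed.
        self._queue.put(exploration)


if __name__ == '__main__':
    import time
    mgr = ToyRunnerManager(
        [lambda e, i=i: print("runner", i, "exploration", e) for i in range(2)],
        maximum_staleness=8)
    for eps in (1.0, 0.5, 0.1):
        mgr.step(eps)
    time.sleep(0.1)   # give the worker threads a moment to drain the queue
```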
 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
@@ -201,13 +264,22 @@ class ExpReplay(DataFlow, Callback):
     This implementation provides the interface as a :class:`DataFlow`.
     This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

-    This implementation assumes that state is
-    batch-able, and the network takes batched inputs.
+    It does the following:
+
+    * Spawn `num_parallel_players` environment threads, each running an instance
+      of the environment with an epsilon-greedy policy.
+    * All environment instances write their experiences to a shared replay
+      memory buffer.
+    * Produces batched samples by sampling the replay buffer. After producing
+      each batch, it executes the environment instances for a total of
+      `update_frequency` steps.
+
+    This implementation assumes that state is batch-able, and the network takes batched inputs.
     """
     def __init__(self,
                  predictor_io_names,
-                 player,
+                 get_player,
+                 num_parallel_players,
                  state_shape,
                  batch_size,
                  memory_size, init_memory_size,
@@ -217,7 +289,11 @@ class ExpReplay(DataFlow, Callback):
         Args:
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
-            player (gym.Env): the player.
+            get_player (-> gym.Env): a callable which returns a player.
+            num_parallel_players (int): number of players to run in parallel.
+                Standard DQN uses 1.
+                Parallelism increases speed, but will affect the distribution of
+                experiences in the replay buffer.
             state_shape (tuple):
             batch_size (int):
             memory_size (int):
@@ -235,8 +311,6 @@ class ExpReplay(DataFlow, Callback):
             if k != 'self':
                 setattr(self, k, v)
         self.exploration = 1.0  # default initial exploration
-        self.num_actions = player.action_space.n
-        logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
@@ -248,7 +322,7 @@ class ExpReplay(DataFlow, Callback):
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                self.env_runner.step(self.exploration)
+                self.runner.step(self.exploration)
                 pbar.update()
         self._init_memory_flag.set()
@@ -257,7 +331,7 @@ class ExpReplay(DataFlow, Callback):
         from copy import deepcopy
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < 5:
-                self.env_runner.step(self.exploration)
+                self.runner.step(self.exploration)
                 pbar.update()
             while len(self.mem) < self.init_memory_size:
                 self.mem.append(deepcopy(self.mem._hist[0]))
@@ -300,43 +374,24 @@ class ExpReplay(DataFlow, Callback):
             yield self._process_batch(batch_exp)
-            # execute 4 new actions into memory, after each batch update
+            # execute update_freq=4 new actions into memory, after each batch update
             for _ in range(self.update_frequency):
-                self.env_runner.step(self.exploration)
+                self.runner.step(self.exploration)

     # Callback methods:
     def _setup_graph(self):
-        predictor = self.trainer.get_predictor(*self.predictor_io_names)
-        self.env_runner = EnvRunner(self.player, predictor, self.mem, self.history_len)
+        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

     def _before_train(self):
+        env_runners = [
+            EnvRunner(self.get_player(), self.predictor, self.mem, self.history_len)
+            for k in range(self.num_parallel_players)
+        ]
+        self.runner = EnvRunnerManager(env_runners, self.update_frequency * 2)
         self._init_memory()

     def _trigger(self):
-        v = self.env_runner._player_scores
-        try:
-            mean, max = v.average, v.max
+        mean, max = self.runner.reset_stats()
+        if mean is not None:
             self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
             self.trainer.monitors.put_scalar('expreplay/max_score', max)
-        except Exception:
-            logger.exception("Cannot log training scores.")
-        v.reset()
-
-
-if __name__ == '__main__':
-    from .atari import AtariPlayer
-    import sys
-
-    def predictor(x):
-        np.array([1, 1, 1, 1])
-
-    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
-    E = ExpReplay(predictor,
-                  player=player,
-                  num_actions=player.get_action_space().num_actions(),
-                  populate_size=1001,
-                  history_len=4)
-    E._init_memory()
-
-    for _ in E.get_data():
-        import IPython as IP
-        IP.embed(config=IP.terminal.ipapp.load_default_config())
@@ -47,6 +47,12 @@ class StatCounter(object):
         assert len(self._values)
         return min(self._values)

+    def samples(self):
+        """
+        Returns all samples.
+        """
+        return self._values
+

 class RatioCounter(object):
     """ A counter to count ratio of something. """
...