Commit 31cfcadf authored by Yuxin Wu

[DQN] split the environment runner from expreplay

parent 6d4a77c7
@@ -108,12 +108,15 @@ def get_config(model):
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE,
init_exploration=1.0,
update_frequency=UPDATE_FREQ,
history_len=FRAME_HISTORY,
state_dtype=model.state_dtype.as_numpy_dtype
)
# Set to other values if you need a different initial exploration
# (e.g., if you're resuming training halfway)
# expreplay.exploration = 1.0
return TrainConfig(
data=QueueInput(expreplay),
model=model,
......
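Since init_exploration is no longer a constructor argument, a training script that wants to anneal exploration during training can schedule the exploration attribute instead. A minimal sketch, assuming tensorpack's ScheduledHyperParamSetter and ObjAttrParam callbacks are available; the schedule values are placeholders, not taken from this commit:

from tensorpack.callbacks import ObjAttrParam, ScheduledHyperParamSetter

# Anneal ExpReplay.exploration over epochs (placeholder values); pass this
# callback, along with the expreplay object itself, via TrainConfig(callbacks=[...]).
exploration_setter = ScheduledHyperParamSetter(
    ObjAttrParam(expreplay, 'exploration'),
    [(0, 1.0), (10, 0.1), (320, 0.01)],  # (epoch, value) pairs
    interp='linear')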
@@ -5,7 +5,7 @@
import copy
import numpy as np
import threading
from collections import deque, namedtuple
from collections import namedtuple
from six.moves import range
from tensorpack.callbacks.base import Callback
@@ -30,7 +30,7 @@ class ReplayMemory(object):
self.max_size = int(max_size)
self.state_shape = state_shape
assert len(state_shape) in [1, 2, 3], state_shape
self._output_shape = self.state_shape + (history_len + 1, )
# self._output_shape = self.state_shape + (history_len + 1, )
self.history_len = int(history_len)
self.dtype = dtype
@@ -45,7 +45,6 @@ class ReplayMemory(object):
self._curr_size = 0
self._curr_pos = 0
self._hist = deque(maxlen=history_len - 1)
def append(self, exp):
"""
@@ -59,17 +58,6 @@ class ReplayMemory(object):
else:
self._assign(self._curr_pos, exp)
self._curr_pos = (self._curr_pos + 1) % self.max_size
if exp.isOver:
self._hist.clear()
else:
self._hist.append(exp)
def recent_state(self):
""" return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
lst = list(self._hist)
states = [np.zeros(self.state_shape, dtype=self.dtype)] * (self._hist.maxlen - len(lst))
states.extend([k.state for k in lst])
return states
def sample(self, idx):
""" return a tuple of (s,r,a,o),
@@ -118,6 +106,92 @@ class ReplayMemory(object):
self.isOver[pos] = exp.isOver
class EnvRunner(object):
"""
A class which is responsible for
stepping the environment with an epsilon-greedy policy,
and filling the results into the experience replay buffer.
"""
def __init__(self, player, predictor, memory, history_len):
"""
Args:
player (gym.Env)
predictor (callable): the model forward function which takes a
state and returns the prediction.
memory (ReplayMemory): the replay memory to store experience to.
history_len (int): length of history frames to concatenate as the state.
"""
self.player = player
self.num_actions = player.action_space.n
self.predictor = predictor
self.memory = memory
self.state_shape = memory.state_shape
self.dtype = memory.dtype
self.history_len = history_len
self._current_episode = []
self._current_ob = player.reset()
self._current_game_score = StatCounter() # store per-step reward
self._player_scores = StatCounter() # store per-game total score
self.rng = get_rng(self)
def step(self, exploration):
"""
Run the environment for one step.
If the episode ends, store the entire episode to the replay memory.
"""
old_s = self._current_ob
if self.rng.rand() <= exploration:
act = self.rng.choice(range(self.num_actions))
else:
history = self.recent_state()
history.append(old_s)
history = np.stack(history, axis=-1) # state_shape + (Hist,)
# assume batched network
history = np.expand_dims(history, axis=0)
q_values = self.predictor(history)[0][0] # this is the bottleneck
act = np.argmax(q_values)
self._current_ob, reward, isOver, info = self.player.step(act)
self._current_game_score.feed(reward)
self._current_episode.append(Experience(old_s, act, reward, isOver))
if isOver:
flush_experience = True
if 'ale.lives' in info: # if running Atari, do something special
if info['ale.lives'] != 0:
# only record score and flush experience
# when a whole game is over (not when an episode is over)
flush_experience = False
self.player.reset()
if flush_experience:
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
# TODO lock here if having multiple runners
for exp in self._current_episode:
self.memory.append(exp)
self._current_episode.clear()
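# Shape bookkeeping for the predict branch above, assuming 84x84 frames and
# history_len=4 as in the Atari DQN example (illustrative, not part of the diff):
#   recent_state()                  -> 3 frames of shape (84, 84)
#   history.append(old_s)           -> 4 frames
#   np.stack(history, axis=-1)      -> (84, 84, 4)     # state_shape + (history_len,)
#   np.expand_dims(history, axis=0) -> (1, 84, 84, 4)  # batch of one for the predictor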
def recent_state(self):
"""
Get the recent state (with stacked history) of the environment.
Returns:
a list of ``hist_len-1`` elements, each of shape ``self.state_shape``
"""
expected_len = self.history_len - 1
if len(self._current_episode) >= expected_len:
return [k.state for k in self._current_episode[-expected_len:]]
else:
states = [np.zeros(self.state_shape, dtype=self.dtype)] * (expected_len - len(self._current_episode))
states.extend([k.state for k in self._current_episode])
return states
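# Illustrative usage of the new EnvRunner outside of ExpReplay -- a sketch, not
# part of this commit. Here env is an assumed gym-style environment that
# already yields preprocessed 84x84 uint8 frames, and dummy_predictor stands in
# for the trained network.
import numpy as np

memory = ReplayMemory(max_size=1000, state_shape=(84, 84), history_len=4, dtype='uint8')

def dummy_predictor(batched_history):
    # step() indexes the result as predictor(history)[0][0]:
    # first output tensor, then the first (and only) row of the batch
    return [np.random.rand(1, env.action_space.n)]

runner = EnvRunner(env, dummy_predictor, memory, history_len=4)
for _ in range(100):
    runner.step(exploration=0.1)  # epsilon-greedy step; a finished game is flushed into memory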
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
@@ -137,7 +211,6 @@ class ExpReplay(DataFlow, Callback):
state_shape,
batch_size,
memory_size, init_memory_size,
init_exploration,
update_frequency, history_len,
state_dtype='uint8'):
"""
@@ -146,10 +219,14 @@ class ExpReplay(DataFlow, Callback):
predict Q value from state.
player (gym.Env): the player.
state_shape (tuple):
history_len (int): length of history frames to concatenate. Initial
frames are zero-filled.
batch_size (int):
memory_size (int):
init_memory_size (int):
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
history_len (int): length of history frames to concatenate. Initial
frames are zero-filled.
state_dtype (str):
"""
assert len(state_shape) in [1, 2, 3], state_shape
init_memory_size = int(init_memory_size)
@@ -157,24 +234,21 @@ class ExpReplay(DataFlow, Callback):
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.exploration = init_exploration
self.exploration = 1.0 # default initial exploration
self.num_actions = player.action_space.n
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.rng = get_rng(self)
self._init_memory_flag = threading.Event() # tell if memory has been initialized
self.mem = ReplayMemory(memory_size, state_shape, history_len)
self._current_ob = self.player.reset()
self._player_scores = StatCounter()
self._current_game_score = StatCounter()
self.mem = ReplayMemory(memory_size, state_shape, self.history_len, dtype=state_dtype)
def _init_memory(self):
logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))
with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
self._populate_exp()
self.env_runner.step(self.exploration)
pbar.update()
self._init_memory_flag.set()
@@ -183,42 +257,13 @@ class ExpReplay(DataFlow, Callback):
from copy import deepcopy
with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < 5:
self._populate_exp()
self.env_runner.step(self.exploration)
pbar.update()
while len(self.mem) < self.init_memory_size:
self.mem.append(deepcopy(self.mem._hist[0]))
pbar.update()
self._init_memory_flag.set()
def _populate_exp(self):
""" populate a transition by epsilon-greedy"""
old_s = self._current_ob
if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
act = self.rng.choice(range(self.num_actions))
else:
# build a history state
history = self.mem.recent_state()
history.append(old_s)
history = np.stack(history, axis=-1) # state_shape + (Hist,)
history = np.expand_dims(history, axis=0)
# assume batched network
q_values = self.predictor(history)[0][0] # this is the bottleneck
act = np.argmax(q_values)
self._current_ob, reward, isOver, info = self.player.step(act)
self._current_game_score.feed(reward)
if isOver:
if 'ale.lives' in info: # if running Atari, do something special for logging:
if info['ale.lives'] == 0:
# only record score when a whole game is over (not when an episode is over)
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
else:
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
self.player.reset()
self.mem.append(Experience(old_s, act, reward, isOver))
def _debug_sample(self, sample):
import cv2
@@ -257,17 +302,18 @@ class ExpReplay(DataFlow, Callback):
# execute update_frequency new actions into memory after each batch update
for _ in range(self.update_frequency):
self._populate_exp()
self.env_runner.step(self.exploration)
# Callback methods:
def _setup_graph(self):
self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
predictor = self.trainer.get_predictor(*self.predictor_io_names)
self.env_runner = EnvRunner(self.player, predictor, self.mem, self.history_len)
def _before_train(self):
self._init_memory()
def _trigger(self):
v = self._player_scores
v = self.env_runner._player_scores
try:
mean, max = v.average, v.max
self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
......
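Taken together, the division of labour after this change is roughly the following (a reading aid, not code from the commit):

# ExpReplay lifecycle after the split (illustrative summary):
#   _setup_graph():   build the predictor and wrap it, together with the player
#                     and the ReplayMemory, in an EnvRunner
#   _before_train():  _init_memory() -- call env_runner.step(exploration) until
#                     the memory holds init_memory_size transitions
#   each iteration:   yield one sampled training batch, then call
#                     env_runner.step(exploration) update_frequency times
#   _trigger():       log mean/max of env_runner._player_scores once per epoch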