Commit 31cfcadf authored by Yuxin Wu

[DQN] split the environment runner from expreplay

parent 6d4a77c7
@@ -108,12 +108,15 @@ def get_config(model):
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
-        init_exploration=1.0,
         update_frequency=UPDATE_FREQ,
         history_len=FRAME_HISTORY,
         state_dtype=model.state_dtype.as_numpy_dtype
     )
+    # Set to other values if you need a different initial exploration
+    # (e.g., if you're resuming a training half-way)
+    # expreplay.exploration = 1.0
     return TrainConfig(
         data=QueueInput(expreplay),
         model=model,
...
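A minimal sketch (not part of the commit) of the resume scenario the new comment describes. The constructor arguments mirror DQN.py; `get_player` and the 0.1 value are only illustrative:

    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),   # hypothetical helper, as in DQN.py
        state_shape=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY,
    )
    # Resuming half-way through an exploration schedule: override the 1.0
    # default attribute before training, instead of passing init_exploration.
    expreplay.exploration = 0.1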
@@ -5,7 +5,7 @@
 import copy
 import numpy as np
 import threading
-from collections import deque, namedtuple
+from collections import namedtuple
 from six.moves import range
 from tensorpack.callbacks.base import Callback
@@ -30,7 +30,7 @@ class ReplayMemory(object):
         self.max_size = int(max_size)
         self.state_shape = state_shape
         assert len(state_shape) in [1, 2, 3], state_shape
-        self._output_shape = self.state_shape + (history_len + 1, )
+        # self._output_shape = self.state_shape + (history_len + 1, )
         self.history_len = int(history_len)
         self.dtype = dtype
@@ -45,7 +45,6 @@ class ReplayMemory(object):
         self._curr_size = 0
         self._curr_pos = 0
-        self._hist = deque(maxlen=history_len - 1)

     def append(self, exp):
         """
@@ -59,17 +58,6 @@ class ReplayMemory(object):
         else:
             self._assign(self._curr_pos, exp)
             self._curr_pos = (self._curr_pos + 1) % self.max_size
-        if exp.isOver:
-            self._hist.clear()
-        else:
-            self._hist.append(exp)
-
-    def recent_state(self):
-        """ return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
-        lst = list(self._hist)
-        states = [np.zeros(self.state_shape, dtype=self.dtype)] * (self._hist.maxlen - len(lst))
-        states.extend([k.state for k in lst])
-        return states

     def sample(self, idx):
         """ return a tuple of (s,r,a,o),
@@ -118,6 +106,92 @@ class ReplayMemory(object):
         self.isOver[pos] = exp.isOver


+class EnvRunner(object):
+    """
+    A class responsible for stepping the environment with an epsilon-greedy
+    policy and filling the results into the experience replay buffer.
+    """
+    def __init__(self, player, predictor, memory, history_len):
+        """
+        Args:
+            player (gym.Env)
+            predictor (callable): the model forward function which takes a
+                state and returns the prediction.
+            memory (ReplayMemory): the replay memory to store experience to.
+            history_len (int):
+        """
+        self.player = player
+        self.num_actions = player.action_space.n
+        self.predictor = predictor
+        self.memory = memory
+        self.state_shape = memory.state_shape
+        self.dtype = memory.dtype
+        self.history_len = history_len
+
+        self._current_episode = []
+        self._current_ob = player.reset()
+        self._current_game_score = StatCounter()  # store per-step reward
+        self._player_scores = StatCounter()  # store per-game total score
+        self.rng = get_rng(self)
+
+    def step(self, exploration):
+        """
+        Run the environment for one step.
+        If the episode ends, store the entire episode to the replay memory.
+        """
+        old_s = self._current_ob
+        if self.rng.rand() <= exploration:
+            act = self.rng.choice(range(self.num_actions))
+        else:
+            history = self.recent_state()
+            history.append(old_s)
+            history = np.stack(history, axis=-1)  # state_shape + (Hist,)
+            # assume batched network
+            history = np.expand_dims(history, axis=0)
+            q_values = self.predictor(history)[0][0]  # this is the bottleneck
+            act = np.argmax(q_values)
+
+        self._current_ob, reward, isOver, info = self.player.step(act)
+        self._current_game_score.feed(reward)
+        self._current_episode.append(Experience(old_s, act, reward, isOver))
+
+        if isOver:
+            flush_experience = True
+            if 'ale.lives' in info:  # if running Atari, do something special
+                if info['ale.lives'] != 0:
+                    # only record score and flush experience
+                    # when a whole game is over (not when an episode is over)
+                    flush_experience = False
+            self.player.reset()
+
+            if flush_experience:
+                self._player_scores.feed(self._current_game_score.sum)
+                self._current_game_score.reset()
+
+                # TODO: lock here if there are multiple runners
+                for exp in self._current_episode:
+                    self.memory.append(exp)
+                self._current_episode.clear()
+
+    def recent_state(self):
+        """
+        Get the recent state (with stacked history) of the environment.
+
+        Returns:
+            a list of ``hist_len-1`` elements, each of shape ``self.state_shape``
+        """
+        expected_len = self.history_len - 1
+        if len(self._current_episode) >= expected_len:
+            return [k.state for k in self._current_episode[-expected_len:]]
+        else:
+            states = [np.zeros(self.state_shape, dtype=self.dtype)] * (expected_len - len(self._current_episode))
+            states.extend([k.state for k in self._current_episode])
+            return states
+
+
 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
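To make the data layout of the new `EnvRunner` concrete, here is a standalone numpy check (not part of the commit) using the Atari-style values `history_len=4` and an `(84, 84)` state: `recent_state()` yields `history_len - 1` frames, zero-padded at the start of an episode, and after appending the current observation, stacking, and adding a batch axis, the network input is `(1, 84, 84, 4)`:

    import numpy as np

    history_len, state_shape = 4, (84, 84)
    episode = [np.ones(state_shape, dtype='uint8')]       # one real frame so far
    pad = (history_len - 1) - len(episode)
    history = [np.zeros(state_shape, dtype='uint8')] * pad + episode
    history.append(np.ones(state_shape, dtype='uint8'))   # current observation
    stacked = np.expand_dims(np.stack(history, axis=-1), axis=0)
    assert stacked.shape == (1, 84, 84, 4)                # batch of one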
@@ -137,7 +211,6 @@ class ExpReplay(DataFlow, Callback):
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
-                init_exploration,
                 update_frequency, history_len,
                 state_dtype='uint8'):
         """
@@ -146,10 +219,14 @@
                 predict Q value from state.
             player (gym.Env): the player.
             state_shape (tuple):
-            history_len (int): length of history frames to concat. Zero-filled
-                initial frames.
+            batch_size (int):
+            memory_size (int):
+            init_memory_size (int):
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
+            history_len (int): length of history frames to concat. Zero-filled
+                initial frames.
+            state_dtype (str):
         """
         assert len(state_shape) in [1, 2, 3], state_shape
         init_memory_size = int(init_memory_size)
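A small worked example (values illustrative, not from the commit) of what the `update_frequency` docstring means for the data/compute ratio:

    update_frequency, batch_size = 4, 64  # e.g. UPDATE_FREQ and BATCH_SIZE
    # After each training batch, `update_frequency` new transitions are stepped
    # into the memory, so each stored transition is sampled roughly
    # batch_size / update_frequency times on average.
    replay_ratio = batch_size / update_frequency
    assert replay_ratio == 16.0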
@@ -157,24 +234,21 @@ class ExpReplay(DataFlow, Callback):
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
-        self.exploration = init_exploration
+        self.exploration = 1.0  # default initial exploration
         self.num_actions = player.action_space.n
         logger.info("Number of Legal actions: {}".format(self.num_actions))

         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
-        self.mem = ReplayMemory(memory_size, state_shape, history_len)
-        self._current_ob = self.player.reset()
-        self._player_scores = StatCounter()
-        self._current_game_score = StatCounter()
+        self.mem = ReplayMemory(memory_size, state_shape, self.history_len, dtype=state_dtype)

     def _init_memory(self):
         logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                self._populate_exp()
+                self.env_runner.step(self.exploration)
                 pbar.update()
         self._init_memory_flag.set()
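The `_init_memory_flag` above is a `threading.Event` that gates the dataflow until the memory is filled; a standalone sketch of that handshake (not part of the commit):

    import threading
    import time

    flag = threading.Event()

    def fill_memory():
        time.sleep(0.1)   # stand-in for populating the replay memory
        flag.set()        # corresponds to self._init_memory_flag.set()

    threading.Thread(target=fill_memory).start()
    flag.wait()           # the consumer blocks here until initialization ends
    print("memory initialized, training may start")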
@@ -183,42 +257,13 @@ class ExpReplay(DataFlow, Callback):
         from copy import deepcopy
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < 5:
-                self._populate_exp()
+                self.env_runner.step(self.exploration)
                 pbar.update()
             while len(self.mem) < self.init_memory_size:
                 self.mem.append(deepcopy(self.mem._hist[0]))
                 pbar.update()
         self._init_memory_flag.set()

-    def _populate_exp(self):
-        """ populate a transition by epsilon-greedy"""
-        old_s = self._current_ob
-        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
-            act = self.rng.choice(range(self.num_actions))
-        else:
-            # build a history state
-            history = self.mem.recent_state()
-            history.append(old_s)
-            history = np.stack(history, axis=-1)  # state_shape + (Hist,)
-            history = np.expand_dims(history, axis=0)
-            # assume batched network
-            q_values = self.predictor(history)[0][0]  # this is the bottleneck
-            act = np.argmax(q_values)
-        self._current_ob, reward, isOver, info = self.player.step(act)
-        self._current_game_score.feed(reward)
-
-        if isOver:
-            if 'ale.lives' in info:  # if running Atari, do something special for logging:
-                if info['ale.lives'] == 0:
-                    # only record score when a whole game is over (not when an episode is over)
-                    self._player_scores.feed(self._current_game_score.sum)
-                    self._current_game_score.reset()
-            else:
-                self._player_scores.feed(self._current_game_score.sum)
-                self._current_game_score.reset()
-            self.player.reset()
-        self.mem.append(Experience(old_s, act, reward, isOver))
-
     def _debug_sample(self, sample):
         import cv2
@@ -257,17 +302,18 @@
             # execute 4 new actions into memory, after each batch update
             for _ in range(self.update_frequency):
-                self._populate_exp()
+                self.env_runner.step(self.exploration)

     # Callback methods:
     def _setup_graph(self):
-        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
+        predictor = self.trainer.get_predictor(*self.predictor_io_names)
+        self.env_runner = EnvRunner(self.player, predictor, self.mem, self.history_len)

     def _before_train(self):
         self._init_memory()

     def _trigger(self):
-        v = self._player_scores
+        v = self.env_runner._player_scores
         try:
             mean, max = v.average, v.max
             self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
...
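Finally, a self-contained toy (dummy stand-ins, not tensorpack) showing the control flow this split produces: an `EnvRunner`-style loop does the epsilon-greedy stepping and flushes whole episodes into the memory, while `ExpReplay` merely schedules `step()` calls and owns the exploration value:

    import numpy as np

    class DummyEnv:
        """Minimal gym-like environment standing in for the Atari player."""
        def __init__(self, n_actions=4, episode_len=5):
            self.n_actions, self.episode_len, self.t = n_actions, episode_len, 0
        def reset(self):
            self.t = 0
            return np.zeros((84, 84), dtype='uint8')
        def step(self, act):
            self.t += 1
            ob = np.full((84, 84), self.t, dtype='uint8')
            return ob, float(act), self.t >= self.episode_len, {}

    rng = np.random.RandomState(0)
    env, memory, episode = DummyEnv(), [], []
    ob = env.reset()
    for _ in range(12):                       # ExpReplay would drive these steps
        exploration = 1.0                     # pure random, as during _init_memory
        if rng.rand() <= exploration:
            act = rng.randint(env.n_actions)  # epsilon branch
        else:
            act = 0                           # greedy branch would call the predictor
        new_ob, reward, isOver, info = env.step(act)
        episode.append((ob, act, reward, isOver))
        ob = new_ob
        if isOver:
            memory.extend(episode)            # flush the whole episode at once
            episode.clear()
            ob = env.reset()
    assert len(memory) == 10                  # two complete 5-step episodes flushed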