Commit e5a48033 authored by Yuxin Wu

speedup expreplay a bit

parent 364fe347
@@ -133,8 +133,8 @@ class Model(ModelDesc):
SummaryGradient()]
def predictor(self, state):
# TODO change to a multitower predictor for speedup
return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
#return self.predict_value.eval(feed_dict={'input_deque:0': [state]})[0]
def get_config():
basename = os.path.basename(__file__)
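
The predictor above evaluates the Q-value tensor for a single state by wrapping it in a batch of one and stripping the batch dimension with [0]. A generic TF1-style sketch of that eval-with-feed_dict pattern (the placeholder name, shapes, and the toy value head are illustrative, not from this commit):

    import tensorflow as tf

    state = tf.placeholder(tf.float32, [None, 2], name='state')
    predict_value = state * 2.0   # stand-in for the real Q-value head

    with tf.Session().as_default():
        # wrap one state into a batch of 1, then take the first (only) result
        print(predict_value.eval(feed_dict={'state:0': [[1.0, 2.0]]})[0])
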
@@ -206,4 +206,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
#QueueInputTrainer(config).train()
# TODO test if QueueInput affects learning
@@ -107,7 +107,7 @@ class AtariPlayer(RLEnvironment):
def current_state(self):
"""
:returns: a gray-scale (h, w, 1) image
:returns: a gray-scale (h, w, 1) float32 image
"""
ret = self._grab_raw_image()
# max-pooled over the last screen
@@ -117,7 +117,7 @@ class AtariPlayer(RLEnvironment):
#m = cv2.resize(ret, (1920,1200))
cv2.imshow(self.windowname, ret)
time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1],:]
ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32')
# 0.299, 0.587, 0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.resize(ret, self.image_shape)
......
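For reference on the luma comment above: cv2.COLOR_RGB2GRAY applies the ITU-R 601 weights 0.299, 0.587, 0.114, and it accepts the float32 frames that the patched code now produces. A small standalone check (not part of the commit):

    import numpy as np
    import cv2

    frame = np.random.randint(0, 255, (210, 160, 3)).astype('float32')
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    manual = frame.dot(np.float32([0.299, 0.587, 0.114]))
    assert np.allclose(gray, manual, atol=1e-3)
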
@@ -5,6 +5,7 @@
import numpy as np
from collections import deque, namedtuple
import threading
from tqdm import tqdm
import six
@@ -48,6 +49,7 @@ class ExpReplay(DataFlow, Callback):
if populate_size is not None:
logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
init_memory_size = populate_size
init_memory_size = int(init_memory_size)
for k, v in locals().items():
if k != 'self':
@@ -56,6 +58,7 @@ class ExpReplay(DataFlow, Callback):
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.mem = deque(maxlen=memory_size)
self.rng = get_rng(self)
self._init_memory_flag = threading.Event()
def _init_memory(self):
logger.info("Populating replay memory...")
@@ -69,13 +72,17 @@ class ExpReplay(DataFlow, Callback):
with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
self._populate_exp()
from copy import deepcopy
self.mem.append(deepcopy(self.mem[0]))
#self._populate_exp()
pbar.update()
self._init_memory_flag.set()
def reset_state(self):
raise RuntimeError("Don't run me in multiple processes")
def _populate_exp(self):
""" populate a transition by epsilon-greedy"""
old_s = self.player.current_state()
if self.rng.rand() <= self.exploration:
act = self.rng.choice(range(self.num_actions))
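
_populate_exp above implements the standard epsilon-greedy rule: with probability `exploration` take a uniformly random action, otherwise act greedily on the predicted Q-values. A self-contained sketch of the same rule (the `predictor` callable here is a stand-in):

    import numpy as np

    def epsilon_greedy(state, predictor, num_actions, exploration, rng):
        if rng.rand() <= exploration:
            # explore: uniformly random action
            return rng.choice(range(num_actions))
        # exploit: action with the highest predicted Q-value
        return int(np.argmax(predictor(state)))

    rng = np.random.RandomState(0)
    act = epsilon_greedy(None, lambda s: [0.1, 0.9, 0.3], 3, 0.1, rng)
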
@@ -101,6 +108,7 @@ class ExpReplay(DataFlow, Callback):
self.mem.append(Experience(old_s, act, reward, isOver))
def get_data(self):
self._init_memory_flag.wait()
# the new state is considered useless if isOver == True
while True:
batch_exp = [self._sample_one() for _ in range(self.batch_size)]
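
The new _init_memory_flag is a plain threading.Event gate: get_data blocks in wait() until _init_memory finishes populating the replay buffer and calls set(), after which every later wait() returns immediately. The pattern in isolation:

    import threading, time

    ready = threading.Event()

    def populate():
        time.sleep(1.0)   # stand-in for filling the replay memory
        ready.set()       # wakes every thread blocked in wait()

    threading.Thread(target=populate).start()
    ready.wait()          # consumer blocks here until set() is called
    print("memory initialized, start sampling batches")
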
@@ -125,27 +133,28 @@ class ExpReplay(DataFlow, Callback):
def _sample_one(self):
""" return the transition tuple for
[idx, idx+history_len] -> [idx+1, idx+1+history_len]
[idx, idx+history_len) -> [idx+1, idx+1+history_len)
it's the transition from state idx+history_len-1 to state idx+history_len
"""
# look for a state to start with
# when x.isOver==True, (x+1).state is of a different episode
idx = self.rng.randint(len(self.mem) - self.history_len - 1)
start_idx = idx + self.history_len - 1
samples = [self.mem[k] for k in range(idx, idx+self.history_len+1)]
def concat(idx):
v = [self.mem[x].state for x in range(idx, idx+self.history_len)]
v = [x.state for x in samples[idx:idx+self.history_len]]
return np.concatenate(v, axis=2)
state = concat(idx)
next_state = concat(idx + 1)
reward = self.mem[start_idx].reward
action = self.mem[start_idx].action
isOver = self.mem[start_idx].isOver
state = concat(0)
next_state = concat(1)
start_mem = samples[-2]
reward, action, isOver = start_mem.reward, start_mem.action, start_mem.isOver
start_idx = self.history_len - 1
# zero-fill history frames that belong to a previous episode
zero_fill = False
for k in range(1, self.history_len):
if self.mem[start_idx-k].isOver:
if samples[start_idx-k].isOver:
zero_fill = True
if zero_fill:
state[:,:,-k-1] = 0
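
The speedup in _sample_one comes from copying the history_len + 1 transitions out of the deque once, into the local `samples` list. Random access into a collections.deque costs O(n) in the distance from the nearest end, and the old concat paid that cost for every frame of both `state` and `next_state`; list indexing and slicing are O(1) per element. A rough micro-benchmark of the two access patterns (sizes are made up):

    import timeit
    from collections import deque

    mem = deque(range(10 ** 6))
    idx, hist = 5 * 10 ** 5, 4       # sample from the middle: worst case for a deque

    def old_way():                   # index the deque separately per frame, per state
        s = [mem[k] for k in range(idx, idx + hist)]
        s2 = [mem[k] for k in range(idx + 1, idx + 1 + hist)]
        return s, s2

    def new_way():                   # one pass into a list, then O(1) slicing
        samples = [mem[k] for k in range(idx, idx + hist + 1)]
        return samples[:hist], samples[1:hist + 1]

    print(timeit.timeit(old_way, number=100), timeit.timeit(new_way, number=100))
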
@@ -157,8 +166,8 @@ class ExpReplay(DataFlow, Callback):
state = np.array([e[0] for e in batch_exp])
next_state = np.array([e[1] for e in batch_exp])
reward = np.array([e[2] for e in batch_exp])
action = np.array([e[3] for e in batch_exp])
isOver = np.array([e[4] for e in batch_exp])
action = np.array([e[3] for e in batch_exp], dtype='int8')
isOver = np.array([e[4] for e in batch_exp], dtype='bool')
return [state, action, reward, next_state, isOver]
# Callback-related:
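
Typing `action` and `isOver` explicitly keeps each batch small: np.array over Python ints defaults to 8-byte integers, so int8 cuts that field 8x, and the explicit bool guards against an accidental integer upcast. For example:

    import numpy as np

    actions = [3, 1, 0, 2] * 16                    # a fake batch of 64 actions
    print(np.array(actions).nbytes)                # 512 (int64 on most platforms)
    print(np.array(actions, dtype='int8').nbytes)  # 64
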
......
@@ -84,8 +84,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
#_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
#print size
#print self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
pass
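
EnqueueThread pairs each input placeholder with the matching component of a datapoint via dict(zip(self.input_vars, dp)) and runs one enqueue op per datapoint; the commented debug prints above were collapsed into a single size-probe comment. A minimal TF1-style sketch of the queue-feeding pattern it implements (shapes and capacity are illustrative):

    import tensorflow as tf

    state = tf.placeholder(tf.float32, [None, 84, 84, 4])
    action = tf.placeholder(tf.int32, [None])

    queue = tf.FIFOQueue(50, [tf.float32, tf.int32])
    enqueue_op = queue.enqueue([state, action])   # run repeatedly by the feeding thread
    dequed_inputs = queue.dequeue()               # consumed when building the training graph
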
@@ -144,7 +143,7 @@ class QueueInputTrainer(Trainer):
def _single_tower_grad(self):
""" Get grad and cost for single-tower"""
self.dequed_inputs = model_inputs = self._get_model_inputs()
self.model.build_graph(model_inputs, True)
self.model.build_graph(self.dequed_inputs, True)
cost_var = self.model.get_cost()
grads = self.config.optimizer.compute_gradients(cost_var)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
......