Commit e5a48033 authored by Yuxin Wu

speedup expreplay a bit

parent 364fe347
@@ -133,8 +133,8 @@ class Model(ModelDesc):
                 SummaryGradient()]
 
     def predictor(self, state):
+        # TODO change to a multitower predictor for speedup
         return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
-        #return self.predict_value.eval(feed_dict={'input_deque:0': [state]})[0]
 
 def get_config():
     basename = os.path.basename(__file__)
@@ -206,4 +206,5 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     SimpleTrainer(config).train()
     #QueueInputTrainer(config).train()
+    # TODO test if QueueInput affects learning
@@ -107,7 +107,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) image
+        :returns: a gray-scale (h, w, 1) float32 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -117,7 +117,7 @@ class AtariPlayer(RLEnvironment):
             #m = cv2.resize(ret, (1920,1200))
             cv2.imshow(self.windowname, ret)
             time.sleep(self.viz)
-        ret = ret[self.height_range[0]:self.height_range[1],:]
+        ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
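As a side note, after this change current_state() hands out float32 pixels. Below is a minimal standalone sketch of the same crop -> float32 -> grayscale -> resize pipeline; the height range and the 84x84 target are illustrative assumptions, not values taken from this diff.

import cv2
import numpy as np

def preprocess_frame(raw, height_range=(35, 204), image_shape=(84, 84)):
    """Crop, convert to float32, grayscale (rgb2y weights), and resize a raw RGB frame."""
    # height_range and image_shape are illustrative defaults, not from this diff
    ret = raw[height_range[0]:height_range[1], :].astype('float32')
    ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
    ret = cv2.resize(ret, image_shape)
    return ret[:, :, np.newaxis]        # (h, w, 1) float32

frame = np.random.randint(0, 256, (210, 160, 3)).astype('uint8')
print(preprocess_frame(frame).shape)    # -> (84, 84, 1)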
...
@@ -5,6 +5,7 @@
 import numpy as np
 from collections import deque, namedtuple
+import threading
 from tqdm import tqdm
 import six
@@ -48,6 +49,7 @@ class ExpReplay(DataFlow, Callback):
         if populate_size is not None:
             logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
             init_memory_size = populate_size
+        init_memory_size = int(init_memory_size)
         for k, v in locals().items():
             if k != 'self':
@@ -56,6 +58,7 @@ class ExpReplay(DataFlow, Callback):
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
+        self._init_memory_flag = threading.Event()
 
     def _init_memory(self):
         logger.info("Populating replay memory...")
@@ -69,13 +72,17 @@ class ExpReplay(DataFlow, Callback):
         with tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                self._populate_exp()
+                from copy import deepcopy
+                self.mem.append(deepcopy(self.mem[0]))
+                #self._populate_exp()
                 pbar.update()
+        self._init_memory_flag.set()
 
     def reset_state(self):
         raise RuntimeError("Don't run me in multiple processes")
 
     def _populate_exp(self):
+        """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()
         if self.rng.rand() <= self.exploration:
             act = self.rng.choice(range(self.num_actions))
@@ -101,6 +108,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem.append(Experience(old_s, act, reward, isOver))
 
     def get_data(self):
+        self._init_memory_flag.wait()
         # new s is considered useless if isOver==True
         while True:
             batch_exp = [self._sample_one() for _ in range(self.batch_size)]
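The new threading.Event is a simple hand-off between the thread that populates the replay memory and the trainer thread consuming get_data(). A minimal sketch of the pattern (my illustration, not code from the commit):

import threading
import time

ready = threading.Event()

def populate():
    time.sleep(0.5)      # stand-in for filling the replay memory
    ready.set()          # corresponds to self._init_memory_flag.set() after _init_memory()

threading.Thread(target=populate).start()
ready.wait()             # corresponds to the new wait() at the top of get_data()
print("memory initialized, start yielding batches")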
@@ -125,27 +133,28 @@ class ExpReplay(DataFlow, Callback):
     def _sample_one(self):
         """ return the transition tuple for
-            [idx, idx+history_len] -> [idx+1, idx+1+history_len]
+            [idx, idx+history_len) -> [idx+1, idx+1+history_len)
             it's the transition from state idx+history_len-1 to state idx+history_len
         """
         # look for a state to start with
         # when x.isOver==True, (x+1).state is of a different episode
         idx = self.rng.randint(len(self.mem) - self.history_len - 1)
-        start_idx = idx + self.history_len - 1
+        samples = [self.mem[k] for k in range(idx, idx+self.history_len+1)]
+
         def concat(idx):
-            v = [self.mem[x].state for x in range(idx, idx+self.history_len)]
+            v = [x.state for x in samples[idx:idx+self.history_len]]
             return np.concatenate(v, axis=2)
-        state = concat(idx)
-        next_state = concat(idx + 1)
-        reward = self.mem[start_idx].reward
-        action = self.mem[start_idx].action
-        isOver = self.mem[start_idx].isOver
+        state = concat(0)
+        next_state = concat(1)
+        start_mem = samples[-2]
+        reward, action, isOver = start_mem.reward, start_mem.action, start_mem.isOver
 
+        start_idx = self.history_len - 1
         # zero-fill state before starting
         zero_fill = False
         for k in range(1, self.history_len):
-            if self.mem[start_idx-k].isOver:
+            if samples[start_idx-k].isOver:
                 zero_fill = True
             if zero_fill:
                 state[:,:,-k-1] = 0
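The _sample_one() rewrite is where most of the speedup comes from: collections.deque indexing is O(1) at the ends but O(n) in the middle, and the old code indexed self.mem separately for every frame of state, next_state and the zero-fill check. Copying the [idx, idx+history_len] window into a plain list once makes all later lookups cheap. An illustrative sketch of the two access patterns (not from the commit):

from collections import deque, namedtuple
import numpy as np

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
mem = deque(Experience(np.zeros((84, 84, 1), 'float32'), 0, 0.0, False)
            for _ in range(10000))
history_len, idx = 4, 5000

# old style: one O(n) deque lookup per frame
state_old = np.concatenate(
    [mem[x].state for x in range(idx, idx + history_len)], axis=2)

# new style: pull the whole window out of the deque once, then index the list
samples = [mem[k] for k in range(idx, idx + history_len + 1)]
state_new = np.concatenate([x.state for x in samples[:history_len]], axis=2)

assert np.array_equal(state_old, state_new)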
@@ -157,8 +166,8 @@ class ExpReplay(DataFlow, Callback):
         state = np.array([e[0] for e in batch_exp])
         next_state = np.array([e[1] for e in batch_exp])
         reward = np.array([e[2] for e in batch_exp])
-        action = np.array([e[3] for e in batch_exp])
-        isOver = np.array([e[4] for e in batch_exp])
+        action = np.array([e[3] for e in batch_exp], dtype='int8')
+        isOver = np.array([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]
 
     # Callback-related:
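Storing the action and isOver columns as int8 and bool also keeps the sampled batches small, since np.array would otherwise promote them to a platform-dependent integer type. A quick illustration (the batch size is an arbitrary example):

import numpy as np

batch = 64
actions_default = np.array([3] * batch)                # int64 on most 64-bit platforms
actions_int8 = np.array([3] * batch, dtype='int8')
print(actions_default.nbytes, actions_int8.nbytes)     # e.g. 512 vs 64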
...
@@ -84,8 +84,7 @@ class EnqueueThread(threading.Thread):
                     if self.coord.should_stop():
                         return
                     feed = dict(zip(self.input_vars, dp))
-                    #_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
-                    #print size
+                    #print self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                     self.op.run(feed_dict=feed)
         except tf.errors.CancelledError as e:
             pass
@@ -144,7 +143,7 @@ class QueueInputTrainer(Trainer):
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_model_inputs()
-        self.model.build_graph(model_inputs, True)
+        self.model.build_graph(self.dequed_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
...