Commit b852d652 authored by Yuxin Wu

reverse bug in expreplay

parent 4c7348c3
...@@ -35,14 +35,13 @@ BATCH_SIZE = 32 ...@@ -35,14 +35,13 @@ BATCH_SIZE = 32
IMAGE_SIZE = 84 IMAGE_SIZE = 84
NUM_ACTIONS = None NUM_ACTIONS = None
FRAME_HISTORY = 4 FRAME_HISTORY = 4
ACTION_REPEAT = 3 ACTION_REPEAT = 4
#HEIGHT_RANGE = (36, 204) # for breakout HEIGHT_RANGE = (36, 204) # for breakout
HEIGHT_RANGE = (28, -8) # for pong #HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99 GAMMA = 0.99
BATCH_SIZE = 32
INIT_EXPLORATION = 1 INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.0020 EXPLORATION_EPOCH_ANNEAL = 0.008
END_EXPLORATION = 0.1 END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6 MEMORY_SIZE = 1e6
...@@ -71,8 +70,8 @@ class Model(ModelDesc): ...@@ -71,8 +70,8 @@ class Model(ModelDesc):
l = Conv2D('conv2', l, out_channel=64, kernel_shape=4) l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
l = MaxPooling('pool2', l, 2) l = MaxPooling('pool2', l, 2)
l = Conv2D('conv3', l, out_channel=64, kernel_shape=3) l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
l = MaxPooling('pool3', l, 2) #l = MaxPooling('pool3', l, 2)
l = Conv2D('conv4', l, out_channel=64, kernel_shape=3) #l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)
l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name)) l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity) l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
...@@ -90,13 +89,14 @@ class Model(ModelDesc): ...@@ -90,13 +89,14 @@ class Model(ModelDesc):
with tf.variable_scope('target'): with tf.variable_scope('target'):
targetQ_predict_value = tf.stop_gradient( targetQ_predict_value = tf.stop_gradient(
self._get_DQN_prediction(next_state, False)) # NxA self._get_DQN_prediction(next_state, False)) # NxA
target = tf.select(isOver, reward, reward + target = reward + (1 - tf.cast(isOver, tf.int32)) *
GAMMA * tf.reduce_max(targetQ_predict_value, 1)) # Nx1 GAMMA * tf.reduce_max(targetQ_predict_value, 1) # Nx1
sqrcost = tf.square(target - pred_action_value) sqrcost = tf.square(target - pred_action_value)
abscost = tf.abs(target - pred_action_value) # robust error func abscost = tf.abs(target - pred_action_value) # robust error func
cost = tf.select(abscost < 1, sqrcost, abscost) cost = tf.select(abscost < 1, sqrcost, abscost)
summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms']) ]) # monitor all W
self.cost = tf.reduce_mean(cost, name='cost') self.cost = tf.reduce_mean(cost, name='cost')
def update_target_param(self): def update_target_param(self):
......
...@@ -88,6 +88,9 @@ class Callback(object): ...@@ -88,6 +88,9 @@ class Callback(object):
def _trigger_epoch(self): def _trigger_epoch(self):
pass pass
def __str__(self):
return type(self).__name__
class ProxyCallback(Callback): class ProxyCallback(Callback):
def __init__(self, cb): def __init__(self, cb):
self.cb = cb self.cb = cb
...@@ -104,6 +107,9 @@ class ProxyCallback(Callback): ...@@ -104,6 +107,9 @@ class ProxyCallback(Callback):
def _trigger_epoch(self): def _trigger_epoch(self):
self.cb.trigger_epoch() self.cb.trigger_epoch()
def __str__(self):
return str(self.cb)
class PeriodicCallback(ProxyCallback): class PeriodicCallback(ProxyCallback):
""" """
A callback to be triggered after every `period` epochs. A callback to be triggered after every `period` epochs.
...@@ -122,3 +128,6 @@ class PeriodicCallback(ProxyCallback): ...@@ -122,3 +128,6 @@ class PeriodicCallback(ProxyCallback):
self.cb.epoch_num = self.epoch_num - 1 self.cb.epoch_num = self.epoch_num - 1
self.cb.trigger_epoch() self.cb.trigger_epoch()
def __str__(self):
return "Periodic-" + str(self.cb)
...@@ -141,8 +141,9 @@ class Callbacks(Callback): ...@@ -141,8 +141,9 @@ class Callbacks(Callback):
test_sess_restored = False test_sess_restored = False
for cb in self.cbs: for cb in self.cbs:
display_name = str(cb)
if isinstance(cb.type, TrainCallbackType): if isinstance(cb.type, TrainCallbackType):
with tm.timed_callback(type(cb).__name__): with tm.timed_callback(display_name):
cb.trigger_epoch() cb.trigger_epoch()
else: else:
if not test_sess_restored: if not test_sess_restored:
...@@ -150,6 +151,6 @@ class Callbacks(Callback): ...@@ -150,6 +151,6 @@ class Callbacks(Callback):
self.test_callback_context.restore_checkpoint() self.test_callback_context.restore_checkpoint()
test_sess_restored = True test_sess_restored = True
with self.test_callback_context.test_context(), \ with self.test_callback_context.test_context(), \
tm.timed_callback(type(cb).__name__): tm.timed_callback(display_name):
cb.trigger_epoch() cb.trigger_epoch()
tm.log() tm.log()
...@@ -159,18 +159,22 @@ class ExpReplay(DataFlow, Callback): ...@@ -159,18 +159,22 @@ class ExpReplay(DataFlow, Callback):
else: else:
# build a history state # build a history state
ss = [old_s] ss = [old_s]
isOver = False
for k in range(1, self.history_len): for k in range(1, self.history_len):
hist_exp = self.mem[-k] hist_exp = self.mem[-k]
if hist_exp.isOver: if hist_exp.isOver:
isOver = True
if isOver:
ss.append(np.zeros_like(ss[0])) ss.append(np.zeros_like(ss[0]))
else: else:
ss.append(hist_exp.state) ss.append(hist_exp.state)
ss.reverse()
ss = np.concatenate(ss, axis=2) ss = np.concatenate(ss, axis=2)
act = np.argmax(self.predictor(ss)) act = np.argmax(self.predictor(ss))
reward, isOver = self.player.action(act) reward, isOver = self.player.action(act)
if self.reward_clip: if self.reward_clip:
reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1]) reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
self.mem.append(Experience(old_s, act, reward, isOver)) self.mem.append(Experience(old_s, act, reward, isOver))
def get_data(self): def get_data(self):
...@@ -178,17 +182,18 @@ class ExpReplay(DataFlow, Callback): ...@@ -178,17 +182,18 @@ class ExpReplay(DataFlow, Callback):
while True: while True:
batch_exp = [self.sample_one() for _ in range(self.batch_size)] batch_exp = [self.sample_one() for _ in range(self.batch_size)]
def view_state(state, next_state): #def view_state(state, next_state):
""" for debug state representation""" #""" for debugging state representation"""
r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1) #r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1) #r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
print r.shape #r = np.concatenate([r, r2], axis=0)
r = np.concatenate([r, r2], axis=0) #print r.shape
cv2.imshow("state", r) #cv2.imshow("state", r)
cv2.waitKey() #cv2.waitKey()
exp = batch_exp[0] #exp = batch_exp[0]
print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4]) #print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
view_state(exp[0], exp[1]) #if exp[2] or exp[4]:
#view_state(exp[0], exp[1])
yield self._process_batch(batch_exp) yield self._process_batch(batch_exp)
for _ in range(self.new_experience_per_step): for _ in range(self.new_experience_per_step):
...@@ -247,13 +252,12 @@ class ExpReplay(DataFlow, Callback): ...@@ -247,13 +252,12 @@ class ExpReplay(DataFlow, Callback):
self.trainer.write_scalar_summary('expreplay/' + k, v) self.trainer.write_scalar_summary('expreplay/' + k, v)
self.player.reset_stat() self.player.reset_stat()
if __name__ == '__main__': if __name__ == '__main__':
from tensorpack.dataflow.dataset import AtariPlayer from tensorpack.dataflow.dataset import AtariPlayer
import sys import sys
predictor = lambda x: np.array([1,1,1,1]) predictor = lambda x: np.array([1,1,1,1])
predictor.initialized = False predictor.initialized = False
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=20) player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
E = ExpReplay(predictor, E = ExpReplay(predictor,
player=player, player=player,
num_actions=player.get_num_actions(), num_actions=player.get_num_actions(),
...@@ -262,6 +266,8 @@ if __name__ == '__main__': ...@@ -262,6 +266,8 @@ if __name__ == '__main__':
E.init_memory() E.init_memory()
for k in E.get_data(): for k in E.get_data():
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
pass pass
#import IPython; #import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config()) #IPython.embed(config=IPython.terminal.ipapp.load_default_config())
......
...@@ -73,7 +73,9 @@ def add_param_summary(summary_lists): ...@@ -73,7 +73,9 @@ def add_param_summary(summary_lists):
for p in params: for p in params:
name = p.name name = p.name
for rgx, actions in summary_lists: for rgx, actions in summary_lists:
if re.search(rgx, name): if not rgx.endswith('$'):
rgx = rgx + '$'
if re.match(rgx, name):
for act in actions: for act in actions:
perform(p, act) perform(p, act)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment