Commit fefdcfb1 authored by Yuxin Wu

use get_predict_func for DQN
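
Instead of taking a ready-made predictor callable, `ExpReplay` now takes the input/output tensor names and builds its own predict function from the trainer in `_setup_graph`.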

parent bbc17cb1
@@ -134,15 +134,12 @@ class Model(ModelDesc):
                 tf.clip_by_global_norm([grad], 5)[0][0]),
             SummaryGradient()]
 
-    def predictor(self, state):
-        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
-
 
 def get_config():
     logger.auto_set_dir()
     M = Model()
     dataset_train = ExpReplay(
-        predictor=M.predictor,
+        predictor_io_names=(['state'], ['fct/output']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
...
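
On the model side, the hand-written `Model.predictor` method, which evaluated the output tensor through a hard-coded `feed_dict`, is deleted; `get_config` instead names the tensors that define the predictor: `'state'` as the input and `'fct/output'` (presumably the Q-value head) as the output.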
@@ -29,7 +29,7 @@ class ExpReplay(DataFlow, Callback):
     This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
     """
     def __init__(self,
-                 predictor,
+                 predictor_io_names,
                  player,
                  batch_size=32,
                  memory_size=1e6,
@@ -64,6 +64,7 @@ class ExpReplay(DataFlow, Callback):
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
+        self._predictor_io_names = predictor_io_names
 
     def _init_memory(self):
         logger.info("Populating replay memory...")
@@ -90,6 +91,7 @@ class ExpReplay(DataFlow, Callback):
             act = self.rng.choice(range(self.num_actions))
         else:
             # build a history state
+            # XXX assume a state can be represented by one tensor
             ss = [old_s]
 
             isOver = False
@@ -103,7 +105,9 @@ class ExpReplay(DataFlow, Callback):
                 ss.append(hist_exp.state)
             ss.reverse()
             ss = np.concatenate(ss, axis=2)
-            act = np.argmax(self.predictor(ss))
+            # XXX assume batched network
+            q_values = self.predictor([[ss]])[0][0]
+            act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
@@ -171,6 +175,9 @@ class ExpReplay(DataFlow, Callback):
         isOver = np.array([e[4] for e in batch_exp], dtype='bool')
         return [state, action, reward, next_state, isOver]
 
+    def _setup_graph(self):
+        self.predictor = self.trainer.get_predict_func(*self._predictor_io_names)
+
     # Callback-related:
     def _before_train(self):
         # spawn a separate thread to run policy, can speed up 1.3x
@@ -204,7 +211,6 @@ if __name__ == '__main__':
     from .atari import AtariPlayer
     import sys
     predictor = lambda x: np.array([1,1,1,1])
-    predictor.initialized = False
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
     E = ExpReplay(predictor,
                   player=player,
...
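
To summarize the pattern this commit introduces: the replay dataflow no longer receives a predictor at construction time; it stores tensor names and builds the callable once the trainer (and therefore the graph and session) exists. Below is a minimal, self-contained sketch of that idea. `TrainerStub`, `ReplayCollector`, and `choose_action` are hypothetical illustrations, not tensorpack's classes; only the `get_predict_func(input_names, output_names)` signature and the `predictor([[ss]])[0][0]` call convention are taken from the diff above.

```python
import numpy as np

class TrainerStub:
    """Hypothetical stand-in for the trainer, which owns the graph/session.

    get_predict_func mirrors the call convention used by this commit:
    the returned function maps a list of input batches to a list of
    output batches, one entry per tensor name.
    """
    def get_predict_func(self, input_names, output_names):
        num_actions = 4  # assumption for the sketch

        def predict(inputs):
            # inputs[0] is the batch for the single 'state' input
            batch = np.asarray(inputs[0])
            # fake Q-values of shape (batch_size, num_actions);
            # the real function would run the session on output_names
            return [np.zeros((len(batch), num_actions))]
        return predict

class ReplayCollector:
    """Hypothetical sketch of ExpReplay's callback side."""
    def __init__(self, predictor_io_names):
        # only tensor names are stored; no session is needed yet
        self._predictor_io_names = predictor_io_names
        self.predictor = None

    def _setup_graph(self, trainer):
        # deferred construction: in tensorpack the callback reads
        # self.trainer instead of receiving it as an argument
        self.predictor = trainer.get_predict_func(*self._predictor_io_names)

    def choose_action(self, state):
        # [[state]]: a batch of one for the one input tensor, matching
        # q_values = self.predictor([[ss]])[0][0] in the diff
        q_values = self.predictor([[state]])[0][0]
        return int(np.argmax(q_values))

collector = ReplayCollector((['state'], ['fct/output']))
collector._setup_graph(TrainerStub())
print(collector.choose_action(np.zeros((84, 84, 4))))
```

The design point is the decoupling: `ExpReplay` no longer needs a live session or a particular `Model` method at construction time, so the dataflow can be created before the graph is built and only wires itself to the network once training starts.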