Commit fefdcfb1 authored by Yuxin Wu

use get_predict_func for DQN

parent bbc17cb1
@@ -134,15 +134,12 @@ class Model(ModelDesc):
                 tf.clip_by_global_norm([grad], 5)[0][0]),
             SummaryGradient()]

-    def predictor(self, state):
-        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
-
 def get_config():
     logger.auto_set_dir()

     M = Model()
     dataset_train = ExpReplay(
-        predictor=M.predictor,
+        predictor_io_names=(['state'], ['fct/output']),
         player=get_player(train=True),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
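Editor's note: the point of this hunk is that ExpReplay no longer receives a Python function bound to the model; it receives tensor names and asks the trainer for a predict function later (see the `_setup_graph` hunk below). A minimal sketch of the resulting calling convention, assuming a tensorpack-style predict function that maps a list of batched input values to a list of batched output values; `fake_predict_func`, the 84x84x4 state shape, and `NUM_ACTIONS` are illustrative assumptions, not part of the commit:

```python
import numpy as np

NUM_ACTIONS = 4  # assumption: size of the game's action space

def fake_predict_func(inputs):
    # Hypothetical stand-in for a predict function built from
    # (['state'], ['fct/output']): takes a list of batched inputs,
    # returns a list of batched outputs.
    states = np.asarray(inputs[0])                     # the batched 'state' value
    return [np.zeros((states.shape[0], NUM_ACTIONS))]  # batched Q-values

state = np.zeros((84, 84, 4))                  # one stacked Atari observation
q_values = fake_predict_func([[state]])[0][0]  # batch of one in, unwrap twice
act = int(np.argmax(q_values))
```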
@@ -29,7 +29,7 @@ class ExpReplay(DataFlow, Callback):
    This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
    """
    def __init__(self,
-            predictor,
+            predictor_io_names,
            player,
            batch_size=32,
            memory_size=1e6,
@@ -64,6 +64,7 @@ class ExpReplay(DataFlow, Callback):
        self.mem = deque(maxlen=memory_size)
        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized
+        self._predictor_io_names = predictor_io_names

    def _init_memory(self):
        logger.info("Populating replay memory...")
@@ -90,6 +91,7 @@ class ExpReplay(DataFlow, Callback):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
+            # XXX assume a state can be represented by one tensor
            ss = [old_s]

            isOver = False
@@ -103,7 +105,9 @@ class ExpReplay(DataFlow, Callback):
                ss.append(hist_exp.state)
            ss.reverse()
            ss = np.concatenate(ss, axis=2)
-            act = np.argmax(self.predictor(ss))
+            # XXX assume batched network
+            q_values = self.predictor([[ss]])[0][0]
+            act = np.argmax(q_values)
            reward, isOver = self.player.action(act)
        if self.reward_clip:
            reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
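Editor's note on the context lines above: `ss` collects the current frame plus the most recent frames from memory (newest first), reverses them into chronological order, and concatenates along the channel axis to form one history state. A small sketch of that stacking, assuming 84x84 single-channel frames and a history length of 4 (both shapes are assumptions):

```python
import numpy as np

# Frames are appended newest-first above, then reversed so the
# oldest frame comes first in the stacked state.
ss = [np.zeros((84, 84, 1)) for _ in range(4)]
ss.reverse()
state = np.concatenate(ss, axis=2)  # the channel axis holds the history
assert state.shape == (84, 84, 4)
```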
@@ -171,6 +175,9 @@ class ExpReplay(DataFlow, Callback):
        isOver = np.array([e[4] for e in batch_exp], dtype='bool')
        return [state, action, reward, next_state, isOver]

+    def _setup_graph(self):
+        self.predictor = self.trainer.get_predict_func(*self._predictor_io_names)
+
    # Callback-related:
    def _before_train(self):
        # spawn a separate thread to run policy, can speed up 1.3x
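Editor's note: this hunk is the other half of the API change. `_setup_graph` runs after the training graph exists, which is the earliest point a predict function can be built from tensor names; `__init__` only stores the names. A sketch of the pattern in isolation, assuming the Callback interface visible in this diff (the class name and import path are assumptions):

```python
from tensorpack.callbacks import Callback  # assumed import path


class QPredictorHolder(Callback):
    """Hypothetical callback illustrating deferred predictor construction."""
    def __init__(self, predictor_io_names):
        # Only names are stored at construction time; no graph or
        # session is required yet.
        self._io_names = predictor_io_names  # e.g. (['state'], ['fct/output'])

    def _setup_graph(self):
        # self.trainer is attached by the framework before this hook
        # runs, so the predict function can be built from names here.
        self.predictor = self.trainer.get_predict_func(*self._io_names)
```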
@@ -204,7 +211,6 @@ if __name__ == '__main__':
    from .atari import AtariPlayer
    import sys
    predictor = lambda x: np.array([1,1,1,1])
-    predictor.initialized = False
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
            player=player,
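Editor's note: one loose end in this quick test: the stub predictor still follows the old unbatched convention, and it is still passed as the first positional argument (now named `predictor_io_names`). Under the new `predictor([[ss]])[0][0]` convention, a stub would have to return a list of batched outputs; a hedged sketch (the 4-action shape is an assumption):

```python
import numpy as np

# A stub matching the batched convention: one output tensor,
# batch size one, uniform Q-values over an assumed 4-action game.
predictor = lambda x: [np.ones((1, 4))]
q_values = predictor([[None]])[0][0]  # -> array([1., 1., 1., 1.])
```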