Commit ec726f6c authored by Yuxin Wu's avatar Yuxin Wu

clean-ups in DQN

parent da9b1b2f
......@@ -30,8 +30,9 @@ from tensorpack.RL import *
"""
Implement DQN in:
Human-level control through deep reinforcement learning
for atari games
Human-level Control Through Deep Reinforcement Learning
for atari games. Use the variants in:
Deep Reinforcement Learning with Double Q-learning.
"""
BATCH_SIZE = 32
......@@ -139,10 +140,8 @@ class Model(ModelDesc):
tf.clip_by_global_norm([grad], 5)[0][0]),
SummaryGradient()]
def current_predictor(state):
pred_var = tf.get_default_graph().get_tensor_by_name('fct/output:0')
pred = pred_var.eval(feed_dict={'state:0': [state]})
return pred[0]
def predictor(self, state):
return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
def play_one_episode(player, func, verbose=False):
while True:
......@@ -237,7 +236,7 @@ def get_config():
M = Model()
dataset_train = ExpReplay(
predictor=current_predictor,
predictor=M.predictor,
player=get_player(train=True),
num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE,
......@@ -246,6 +245,7 @@ def get_config():
exploration=INIT_EXPLORATION,
end_exploration=END_EXPLORATION,
exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
update_frequency=4,
reward_clip=(-1, 1),
history_len=FRAME_HISTORY)
......
......@@ -5,7 +5,6 @@
import multiprocessing
import threading
import zmq
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
......@@ -15,6 +14,13 @@ from tensorpack.utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster']
try:
import zmq
except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = []
class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """
__metaclass__ = ABCMeta
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment