Commit ec726f6c authored by Yuxin Wu

clean-ups in DQN

parent da9b1b2f
@@ -30,8 +30,9 @@ from tensorpack.RL import *
 """
 Implement DQN in:
-Human-level control through deep reinforcement learning
-for atari games
+Human-level Control Through Deep Reinforcement Learning
+for atari games. Use the variants in:
+Deep Reinforcement Learning with Double Q-learning.
 """
 BATCH_SIZE = 32
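Note: the new docstring references the Double Q-learning variant, in which the online network selects the best next action and the target network evaluates it. A minimal illustrative sketch of that target computation, assuming NumPy arrays of Q-values for a batch of next states (none of these names come from the commit):

    import numpy as np

    def double_dqn_targets(rewards, done, q_next_online, q_next_target, gamma=0.99):
        # Double DQN: pick the argmax actions with the online network...
        best_actions = np.argmax(q_next_online, axis=1)
        # ...but evaluate those actions with the (frozen) target network.
        next_values = q_next_target[np.arange(len(best_actions)), best_actions]
        # Bootstrapped targets; no bootstrap on terminal transitions.
        return rewards + gamma * (1.0 - done) * next_values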
@@ -139,10 +140,8 @@ class Model(ModelDesc):
                 tf.clip_by_global_norm([grad], 5)[0][0]),
             SummaryGradient()]
-def current_predictor(state):
-    pred_var = tf.get_default_graph().get_tensor_by_name('fct/output:0')
-    pred = pred_var.eval(feed_dict={'state:0': [state]})
-    return pred[0]
+    def predictor(self, state):
+        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
 def play_one_episode(player, func, verbose=False):
     while True:
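Note: the free function current_predictor, which looked up the 'fct/output:0' tensor by name, becomes a Model.predictor method that evaluates self.predict_value directly, so get_config() can simply pass the bound method M.predictor. A hedged sketch of how such a predictor callable is typically consumed for epsilon-greedy exploration (the helper below is illustrative, not ExpReplay's actual code):

    import random
    import numpy as np

    def epsilon_greedy(predictor, state, num_actions, epsilon):
        # `predictor` is any callable mapping a state to a vector of Q-values,
        # e.g. the bound method M.predictor from the hunk above.
        if random.random() < epsilon:
            return random.randrange(num_actions)
        return int(np.argmax(predictor(state)))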
@@ -237,7 +236,7 @@ def get_config():
     M = Model()
     dataset_train = ExpReplay(
-        predictor=current_predictor,
+        predictor=M.predictor,
         player=get_player(train=True),
         num_actions=NUM_ACTIONS,
         memory_size=MEMORY_SIZE,
@@ -246,6 +245,7 @@ def get_config():
         exploration=INIT_EXPLORATION,
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
+        update_frequency=4,
         reward_clip=(-1, 1),
         history_len=FRAME_HISTORY)
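Note: update_frequency=4 follows the common DQN schedule of one gradient update per four environment steps. An illustrative loop showing what that schedule means in practice (the function and attribute names here are assumptions, not ExpReplay internals):

    def training_loop(env, agent, buffer, num_steps, update_frequency=4, batch_size=32):
        # One SGD update every `update_frequency` environment steps,
        # once the replay buffer holds at least one batch.
        state = env.reset()
        for step in range(1, num_steps + 1):
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            buffer.append((state, action, reward, next_state, done))
            state = env.reset() if done else next_state
            if step % update_frequency == 0 and len(buffer) >= batch_size:
                agent.train_on_batch(buffer.sample(batch_size))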
@@ -5,7 +5,6 @@
 import multiprocessing
 import threading
-import zmq
 import weakref
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict, namedtuple
@@ -15,6 +14,13 @@ from tensorpack.utils.concurrency import *
 __all__ = ['SimulatorProcess', 'SimulatorMaster']
+try:
+    import zmq
+except ImportError:
+    logger.warn("Error in 'import zmq'. RL simulator won't be available.")
+    __all__ = []
 class SimulatorProcess(multiprocessing.Process):
     """ A process that simulates a player """
     __metaclass__ = ABCMeta
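Note: wrapping import zmq in try/except makes pyzmq an optional dependency; when it is missing, the module still imports but clears __all__ so the simulator classes are not re-exported by a wildcard import. The same guard pattern in a self-contained sketch (the logging setup here is an assumption, not tensorpack's own logger):

    import logging
    logger = logging.getLogger(__name__)

    __all__ = ['SimulatorProcess', 'SimulatorMaster']

    try:
        import zmq  # optional: only the RL simulator needs pyzmq
    except ImportError:
        logger.warning("Error in 'import zmq'. RL simulator won't be available.")
        __all__ = []  # hide the simulator classes when the dependency is absent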