Commit 6452b444 authored by Yuxin Wu

update DQN code

parent 0870401c
@@ -43,7 +43,7 @@ MEMORY_SIZE = 1e6
 # NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 # Suggest using tcmalloc to manage memory space better.
 INIT_MEMORY_SIZE = 5e4
-STEP_PER_EPOCH = 10000
+STEPS_PER_EPOCH = 10000
 EVAL_EPISODE = 50
 NUM_ACTIONS = None

@@ -54,8 +54,6 @@ METHOD = None
 def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                      image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
-    global NUM_ACTIONS
-    NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
         pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)

@@ -69,9 +67,8 @@ common.get_player = get_player  # so that eval functions in common can use the p
 class Model(ModelDesc):
     def _get_inputs(self):
-        if NUM_ACTIONS is None:
-            p = get_player()
-            del p
+        # use a combined state, where the first channels are the current state,
+        # and the last 4 channels are the next state
         return [InputDesc(tf.uint8,
                           (None,) + IMAGE_SIZE + (CHANNEL + 1,),
                           'comb_state'),

@@ -102,28 +99,31 @@ class Model(ModelDesc):
         if METHOD != 'Dueling':
             Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
         else:
+            # Dueling DQN
             V = FullyConnected('fctV', l, 1, nl=tf.identity)
             As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
             Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
         return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
-        ctx = get_current_tower_context()
         comb_state, action, reward, isOver = inputs
         comb_state = tf.cast(comb_state, tf.float32)
         state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
-        if not ctx.is_training:
+        if not get_current_tower_context().is_training:
             return

+        reward = tf.clip_by_value(reward, -1, 1)
         next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)

         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
         summary.add_moving_summary(max_pred_reward)

-        with tf.variable_scope('target'):
+        with tf.variable_scope('target'), \
+                collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA

         if METHOD != 'Double':

@@ -146,17 +146,6 @@ class Model(ModelDesc):
                           ('fc.*/W', ['histogram', 'rms']))   # monitor all W
         summary.add_moving_summary(self.cost)

-    def update_target_param(self):
-        vars = tf.trainable_variables()
-        ops = []
-        for v in vars:
-            target_name = v.op.name
-            if target_name.startswith('target'):
-                new_name = target_name.replace('target/', '')
-                logger.info("{} <- {}".format(target_name, new_name))
-                ops.append(v.assign(tf.get_default_graph().get_tensor_by_name(new_name + ':0')))
-        return tf.group(*ops, name='update_target_network')
-
     def _get_optimizer(self):
         lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
         opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)

@@ -179,25 +168,36 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        history_len=FRAME_HISTORY,
-        reward_clip=(-1, 1)
+        history_len=FRAME_HISTORY
     )

+    def update_target_param():
+        vars = tf.trainable_variables()
+        ops = []
+        G = tf.get_default_graph()
+        for v in vars:
+            target_name = v.op.name
+            if target_name.startswith('target'):
+                new_name = target_name.replace('target/', '')
+                logger.info("{} <- {}".format(target_name, new_name))
+                ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
+        return tf.group(*ops, name='update_target_network')
+
     return TrainConfig(
         dataflow=expreplay,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
-            RunOp(lambda: M.update_target_param()),
+            RunOp(update_target_param),
             expreplay,
-            StartProcOrThread(expreplay.get_simulator_thread()),
-            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
+            PeriodicTrigger(Evaluator(
+                EVAL_EPISODE, ['state'], ['Qvalue']), every_k_epochs=5),
             # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
         ],
         model=M,
-        steps_per_epoch=STEP_PER_EPOCH,
+        steps_per_epoch=STEPS_PER_EPOCH,
         # run the simulator on a separate GPU if available
         predict_tower=[1] if get_nr_gpu() > 1 else [0],
     )

@@ -221,6 +221,11 @@ if __name__ == '__main__':
     ROM_FILE = args.rom
     METHOD = args.algo

+    # set num_actions
+    pl = AtariPlayer(ROM_FILE, viz=False)
+    NUM_ACTIONS = pl.get_action_space().num_actions()
+    del pl
+
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
...
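The combined-state layout that the new comments describe can be sanity-checked outside TensorFlow: with FRAME_HISTORY = 4 each replay entry stores a single (H, W, 5) frame stack, and the two `tf.slice` calls above recover the current and next 4-frame observations from it. A minimal NumPy sketch (shapes and names are illustrative, not part of this commit):

```python
import numpy as np

FRAME_HISTORY = 4
H, W = 84, 84   # IMAGE_SIZE in DQN.py

# one stored transition: 5 consecutive frames stacked along the channel axis
comb_state = np.random.randint(0, 255, size=(H, W, FRAME_HISTORY + 1), dtype=np.uint8)

# per-sample equivalent of tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4])
state = comb_state[:, :, 0:FRAME_HISTORY]             # frames t-3 .. t
# per-sample equivalent of tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4])
next_state = comb_state[:, :, 1:FRAME_HISTORY + 1]    # frames t-2 .. t+1

assert state.shape == next_state.shape == (H, W, FRAME_HISTORY)
# the two stacks overlap in 3 frames, so only one new frame is stored per step
assert np.array_equal(state[:, :, 1:], next_state[:, :, :-1])
```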
@@ -19,12 +19,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 ![DQN](curve-breakout.png)

-DQN typically took 1.5 days of training to reach a score of 400 on breakout game (same as the paper).
+DQN typically took 1 day of training to reach a score of 400 on breakout game (same as the paper).
 My Batch-A3C implementation only took <2 hours.
 Both were trained on one GPU with an extra GPU for simulation.

-The x-axis is the number of iterations, not wall time.
-Double-DQN is faster at the beginning but will converge to 12 batches/s (768 frames/s) due of exploration annealing.
+Double-DQN runs at 18 batches/s (1152 frames/s) on TitanX.

 ## How to use

@@ -37,9 +36,10 @@ To train:
 # use `--algo` to select other DQN algorithms. See `-h` for more options.
 ```

-To visualize the agent:
+To watch the agent play:
 ```
 ./DQN.py --rom breakout.bin --task play --load trained.model
 ```
+A pretrained model on breakout can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkN1Jrei1xWW0yR28).

 A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
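For reference, the quoted throughput is consistent with the example's training batch size: assuming the default batch size of 64, 18 batches/s × 64 frames per batch = 1152 frames/s (and the removed figure of 12 batches/s corresponds to 768 frames/s).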
@@ -11,7 +11,6 @@ from tqdm import tqdm
 from six.moves import queue

 from tensorpack import *
-from tensorpack.predict import get_predict_func
 from tensorpack.utils.concurrency import *
 from tensorpack.utils.stats import *

@@ -33,7 +32,7 @@ def play_one_episode(player, func, verbose=False):
 def play_model(cfg):
     player = get_player(viz=0.01)
-    predfunc = get_predict_func(cfg)
+    predfunc = OfflinePredictor(cfg)
     while True:
         score = play_one_episode(player, predfunc)
         print("Total:", score)

@@ -96,7 +95,7 @@ def eval_model_multithread(cfg, nr_eval):
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))

-class Evaluator(Callback):
+class Evaluator(Triggerable):
     def __init__(self, nr_eval, input_names, output_names):
         self.eval_episode = nr_eval
         self.input_names = input_names

@@ -107,7 +106,7 @@ class Evaluator(Callback):
         self.pred_funcs = [self.trainer.get_predict_func(
             self.input_names, self.output_names)] * NR_PROC

-    def _trigger_epoch(self):
+    def _trigger(self):
         t = time.time()
         mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
         t = time.time() - t
...
@@ -124,8 +124,7 @@ class ExpReplay(DataFlow, Callback):
                  batch_size,
                  memory_size, init_memory_size,
                  exploration, end_exploration, exploration_epoch_anneal,
-                 update_frequency, history_len,
-                 reward_clip=None):
+                 update_frequency, history_len):
         """
         Args:
             predictor_io_names (tuple of list of str): input/output names to

@@ -191,8 +190,6 @@ class ExpReplay(DataFlow, Callback):
             q_values = self.predictor([[history]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
-        if self.reward_clip:
-            reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         self.mem.append(Experience(old_s, act, reward, isOver))

     def debug_sample(self, sample):

@@ -236,7 +233,8 @@ class ExpReplay(DataFlow, Callback):
     def _before_train(self):
         self._init_memory()
-        # TODO start thread here
+        self._simulator_th = self.get_simulator_thread()
+        self._simulator_th.start()

     def _trigger_epoch(self):
         if self.exploration > self.end_exploration:
...
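The replay-collection code above shows only the greedy branch of action selection, while `_trigger_epoch` linearly anneals the exploration rate once per epoch. A minimal sketch of the assumed ε-greedy scheme with linear annealing (constants and names are illustrative, not the exact ExpReplay implementation):

```python
import numpy as np

rng = np.random.RandomState(0)

exploration = 1.0                 # assumed initial exploration rate
end_exploration = 0.1             # assumed END_EXPLORATION
exploration_epoch_anneal = 0.01   # assumed EXPLORATION_EPOCH_ANNEAL


def choose_action(q_values, num_actions):
    """Epsilon-greedy: random action with probability `exploration`, else argmax of Q."""
    if rng.rand() <= exploration:
        return int(rng.choice(range(num_actions)))
    return int(np.argmax(q_values))


def anneal_exploration():
    """Linear annealing once per epoch, mirroring ExpReplay._trigger_epoch above."""
    global exploration
    if exploration > end_exploration:
        exploration -= exploration_epoch_anneal
```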
@@ -4,11 +4,20 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
+import numpy as np
+import six
 import sys
 import pprint

 from tensorpack.tfutils.varmanip import get_checkpoint_path

-path = get_checkpoint_path(sys.argv[1])
-reader = tf.train.NewCheckpointReader(path)
-pprint.pprint(reader.get_variable_to_shape_map())
+fpath = sys.argv[1]
+
+if fpath.endswith('.npy'):
+    params = np.load(fpath, encoding='latin1').item()
+    dic = {k: v.shape for k, v in six.iteritems(params)}
+else:
+    path = get_checkpoint_path(sys.argv[1])
+    reader = tf.train.NewCheckpointReader(path)
+    dic = reader.get_variable_to_shape_map()
+pprint.pprint(dic)
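The new `.npy` branch expects a file holding a single pickled dict that maps variable names to arrays, which is why the script calls `.item()` on the loaded 0-d object array. A small self-contained check with made-up variable names (newer NumPy additionally requires `allow_pickle=True` on load):

```python
import numpy as np
import six
import pprint

# write a parameter dict in the format the script expects (names are made up)
params = {'conv0/W': np.zeros((5, 5, 4, 32), dtype=np.float32),
          'fc0/b': np.zeros((512,), dtype=np.float32)}
np.save('/tmp/params.npy', params)

# same logic as the '.npy' branch above; allow_pickle is needed on NumPy >= 1.16.3
loaded = np.load('/tmp/params.npy', encoding='latin1', allow_pickle=True).item()
pprint.pprint({k: v.shape for k, v in six.iteritems(loaded)})
# {'conv0/W': (5, 5, 4, 32), 'fc0/b': (512,)}
```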