Commit 4414d3ba authored by Yuxin Wu

refactor RL a bit

parent 249052e0
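In short, the refactor stops the examples from patching `common.get_player` as a module-level global: the player (or a factory for it) is now passed explicitly to `play_model`, `eval_model_multithread`, and `Evaluator`, and the DQN graph code shared between examples moves into a reusable `DQNModel.Model` base class. A minimal before/after sketch of the call sites, using the names as they appear in the diffs below (`cfg` is the prediction config each script builds):

# before: common.py reached back into a patched module global
common.get_player = get_player
play_model(cfg)
eval_model_multithread(cfg, EVAL_EPISODE)

# after: the player, or a factory for it, is an explicit argument
play_model(cfg, get_player(viz=0.01))
eval_model_multithread(cfg, EVAL_EPISODE, get_player)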
@@ -71,9 +71,6 @@ def get_player(viz=False, train=False, dumpdir=None):
return pl
common.get_player = get_player
class MySimulatorWorker(SimulatorProcess):
def _build_player(self):
return get_player(train=True)
@@ -230,7 +227,9 @@ def get_config():
HumanHyperParamSetter('entropy_beta'),
master,
StartProcOrThread(master),
PeriodicTrigger(Evaluator(EVAL_EPISODE, ['state'], ['policy']), every_k_epochs=2),
PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['policy'], get_player),
every_k_epochs=2),
],
session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
@@ -280,9 +279,9 @@ if __name__ == '__main__':
input_names=['state'],
output_names=['policy'])
if args.task == 'play':
play_model(cfg)
play_model(cfg, get_player(viz=0.01))
elif args.task == 'eval':
eval_model_multithread(cfg, args.episode)
eval_model_multithread(cfg, args.episode, get_player)
elif args.task == 'gen_submit':
run_submission(cfg, args.output, args.episode)
else:
......
@@ -18,10 +18,10 @@ from collections import deque
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.RL import *
import tensorflow as tf
from DQNModel import Model as DQNModel
import common
from common import play_model, Evaluator, eval_model_multithread
from atari import AtariPlayer
@@ -61,20 +61,7 @@ def get_player(viz=False, train=False):
return pl
common.get_player = get_player # so that eval functions in common can use the player
class Model(ModelDesc):
def _get_inputs(self):
# use a combined state, where the first channels are the current state,
# and the last 4 channels are the next state
return [InputDesc(tf.uint8,
(None,) + IMAGE_SIZE + (CHANNEL + 1,),
'comb_state'),
InputDesc(tf.int64, (None,), 'action'),
InputDesc(tf.float32, (None,), 'reward'),
InputDesc(tf.bool, (None,), 'isOver')]
class Model(DQNModel):
def _get_DQN_prediction(self, image):
""" image: [0,255]"""
image = image / 255.0
@@ -95,67 +82,20 @@ class Model(ModelDesc):
# .Conv2D('conv2', out_channel=64, kernel_shape=3)
.FullyConnected('fc0', 512, nl=LeakyReLU)())
if METHOD != 'Dueling':
Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
if self.method != 'Dueling':
Q = FullyConnected('fct', l, self.num_actions, nl=tf.identity)
else:
# Dueling DQN
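# Dueling head: predict a scalar state-value V and per-action advantages A,
# then recombine as Q = V + (A - mean(A)); subtracting the mean of A keeps
# the V/A decomposition identifiable.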
V = FullyConnected('fctV', l, 1, nl=tf.identity)
As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
As = FullyConnected('fctA', l, self.num_actions, nl=tf.identity)
Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
return tf.identity(Q, name='Qvalue')
def _build_graph(self, inputs):
comb_state, action, reward, isOver = inputs
comb_state = tf.cast(comb_state, tf.float32)
state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
self.predict_value = self._get_DQN_prediction(state)
if not get_current_tower_context().is_training:
return
reward = tf.clip_by_value(reward, -1, 1)
next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) # N,
max_pred_reward = tf.reduce_mean(tf.reduce_max(
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'), \
collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
targetQ_predict_value = self._get_DQN_prediction(next_state) # NxA
if METHOD != 'Double':
# DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
else:
# Double-DQN
sc = tf.get_variable_scope()
with tf.variable_scope(sc, reuse=True):
next_predict_value = self._get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
self.cost = tf.reduce_mean(symbf.huber_loss(
target - pred_action_value), name='cost')
summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms'])) # monitor all W
summary.add_moving_summary(self.cost)
def _get_optimizer(self):
lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
return optimizer.apply_grad_processors(
opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
def get_config():
logger.auto_set_dir()
M = Model()
M = Model(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
@@ -170,28 +110,17 @@ def get_config():
history_len=FRAME_HISTORY
)
def update_target_param():
vars = tf.global_variables()
ops = []
G = tf.get_default_graph()
for v in vars:
target_name = v.op.name
if target_name.startswith('target'):
new_name = target_name.replace('target/', '')
logger.info("{} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
return TrainConfig(
dataflow=expreplay,
callbacks=[
ModelSaver(),
ScheduledHyperParamSetter('learning_rate',
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
RunOp(update_target_param),
RunOp(DQNModel.update_target_param),
expreplay,
PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['Qvalue']), every_k_epochs=5),
EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
every_k_epochs=5),
# HumanHyperParamSetter('learning_rate', 'hyper.txt'),
# HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
],
@@ -232,9 +161,9 @@ if __name__ == '__main__':
input_names=['state'],
output_names=['Qvalue'])
if args.task == 'play':
play_model(cfg)
play_model(cfg, get_player(viz=0.01))
elif args.task == 'eval':
eval_model_multithread(cfg, EVAL_EPISODE)
eval_model_multithread(cfg, EVAL_EPISODE, get_player)
else:
config = get_config()
if args.load:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: DQNModel.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import abc
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
collection, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils import symbolic_functions as symbf
class Model(ModelDesc):
def __init__(self, image_shape, channel, method, num_actions, gamma):
self.image_shape = image_shape
self.channel = channel
self.method = method
self.num_actions = num_actions
self.gamma = gamma
def _get_inputs(self):
# use a combined state of channel + 1 frames, where the first `channel`
# channels are the current state and the last `channel` channels are the next state
return [InputDesc(tf.uint8,
(None,) + self.image_shape + (self.channel + 1,),
'comb_state'),
InputDesc(tf.int64, (None,), 'action'),
InputDesc(tf.float32, (None,), 'reward'),
InputDesc(tf.bool, (None,), 'isOver')]
@abc.abstractmethod
def _get_DQN_prediction(self, image):
pass
def _build_graph(self, inputs):
comb_state, action, reward, isOver = inputs
comb_state = tf.cast(comb_state, tf.float32)
state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
self.predict_value = self._get_DQN_prediction(state)
if not get_current_tower_context().is_training:
return
reward = tf.clip_by_value(reward, -1, 1)
next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
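# The two states are overlapping channel slices of the same comb_state tensor
# (the hard-coded slice width of 4 assumes a 4-frame history).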
action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) # N,
max_pred_reward = tf.reduce_mean(tf.reduce_max(
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'), \
collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
targetQ_predict_value = self._get_DQN_prediction(next_state) # NxA
if self.method != 'Double':
# DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
else:
# Double-DQN
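# Double-DQN: choose the argmax action with the online network (variables
# reused below), but evaluate that action with the target network to reduce
# the overestimation bias of plain max-Q bootstrapping.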
sc = tf.get_variable_scope()
with tf.variable_scope(sc, reuse=True):
next_predict_value = self._get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, self.num_actions, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * self.gamma * tf.stop_gradient(best_v)
self.cost = tf.reduce_mean(symbf.huber_loss(
target - pred_action_value), name='cost')
summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms'])) # monitor all W
summary.add_moving_summary(self.cost)
def _get_optimizer(self):
lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
return optimizer.apply_grad_processors(
opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
@staticmethod
def update_target_param():
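# Build an op that copies each online-network variable into its counterpart
# under the 'target/' scope; DQN.py runs it through the RunOp callback.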
vars = tf.global_variables()
ops = []
G = tf.get_default_graph()
for v in vars:
target_name = v.op.name
if target_name.startswith('target'):
new_name = target_name.replace('target/', '')
logger.info("{} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
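With the shared pieces in this base class, a concrete model only has to supply the Q-value head. A minimal sketch of a subclass, loosely following what DQN.py does above (the class name MyDQN and the single fully-connected layer are illustrative, not the layers the example actually uses):

import tensorflow as tf
from tensorpack import FullyConnected
from DQNModel import Model as DQNModel

class MyDQN(DQNModel):
    def _get_DQN_prediction(self, image):
        # image comes in as [0, 255]; any feature extractor can go here
        l = FullyConnected('fc0', image / 255.0, 512, nl=tf.nn.relu)
        return tf.identity(
            FullyConnected('fct', l, self.num_actions, nl=tf.identity),
            name='Qvalue')

# The constructor now carries what used to be module-level constants in DQN.py:
# model = MyDQN(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)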
@@ -14,9 +14,6 @@ from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.utils.stats import *
global get_player
get_player = None
def play_one_episode(player, func, verbose=False):
def f(s):
@@ -30,15 +27,14 @@ def play_one_episode(player, func, verbose=False):
return np.mean(player.play_one_episode(f))
def play_model(cfg):
player = get_player(viz=0.01)
def play_model(cfg, player):
predfunc = OfflinePredictor(cfg)
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_with_funcs(predictors, nr_eval):
def eval_with_funcs(predictors, nr_eval, get_player_fn):
class Worker(StoppableThread, ShareSessionThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
@@ -52,7 +48,7 @@ def eval_with_funcs(predictors, nr_eval):
def run(self):
with self.default_sess():
player = get_player(train=False)
player = get_player_fn(train=False)
while not self.stopped():
try:
score = play_one_episode(player, self.func)
@@ -88,18 +84,19 @@ def eval_with_funcs(predictors, nr_eval):
return (0, 0)
def eval_model_multithread(cfg, nr_eval):
def eval_model_multithread(cfg, nr_eval, get_player_fn):
func = OfflinePredictor(cfg)
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Triggerable):
def __init__(self, nr_eval, input_names, output_names):
def __init__(self, nr_eval, input_names, output_names, get_player_fn):
self.eval_episode = nr_eval
self.input_names = input_names
self.output_names = output_names
self.get_player_fn = get_player_fn
def _setup_graph(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 20)
@@ -108,7 +105,8 @@ class Evaluator(Triggerable):
def _trigger(self):
t = time.time()
mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
mean, max = eval_with_funcs(
self.pred_funcs, self.eval_episode, self.get_player_fn)
t = time.time() - t
if t > 10 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.94)
......