Commit ba435f10 authored by Yuxin Wu

update DQN hyperparam

parent 4071cbec
@@ -20,11 +20,12 @@ from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
 
-BATCH_SIZE = 32
+BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4
-HEIGHT_RANGE = (36, 204)    # for breakout
+HEIGHT_RANGE = (None, None)
+#HEIGHT_RANGE = (36, 204)   # for breakout
 #HEIGHT_RANGE = (28, -8)    # for pong
 CHANNEL = FRAME_HISTORY
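
Note: besides doubling BATCH_SIZE from 32 to 64, the default HEIGHT_RANGE becomes (None, None), i.e. no vertical cropping, with the Breakout-specific crop demoted to a comment. A minimal sketch of how such a range is typically consumed (the crop_and_resize helper and the cv2.resize call are illustrative assumptions, not part of this diff):

    import cv2

    def crop_and_resize(frame, height_range=(None, None), image_size=(84, 84)):
        # frame[None:None, :] keeps the full height; (36, 204) would crop
        # Breakout frames and (28, -8) Pong frames before resizing
        lo, hi = height_range
        return cv2.resize(frame[lo:hi, :], image_size)
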
@@ -32,7 +33,7 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 GAMMA = 0.99
 
 INIT_EXPLORATION = 1
-EXPLORATION_EPOCH_ANNEAL = 0.008
+EXPLORATION_EPOCH_ANNEAL = 0.01
 END_EXPLORATION = 0.1
 
 MEMORY_SIZE = 1e6
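
Note: these constants define a linear epsilon-greedy schedule, and raising the per-epoch step from 0.008 to 0.01 shortens the anneal from (1 - 0.1) / 0.008 ≈ 113 epochs to (1 - 0.1) / 0.01 = 90 epochs. A minimal sketch of the schedule implied by the names (per-epoch linear decay is an assumption, not shown in this diff):

    def exploration_at(epoch, init=1.0, anneal=0.01, end=0.1):
        # decay epsilon linearly by `anneal` per epoch, floored at `end`
        return max(end, init - anneal * epoch)

    exploration_at(0)    # 1.0
    exploration_at(45)   # 0.55
    exploration_at(90)   # 0.1, the floor from END_EXPLORATION
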
@@ -133,7 +134,7 @@ class Model(ModelDesc):
                 SummaryGradient()]
 
     def predictor(self, state):
-        # TODO change to a multitower predictor for speedup
+        # TODO use multitower predictor to speed up training
        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
 
 def get_config():
@@ -155,7 +156,7 @@ def get_config():
             reward_clip=(-1, 1),
             history_len=FRAME_HISTORY)
 
-    lr = tf.Variable(0.0004, trainable=False, name='learning_rate')
+    lr = tf.Variable(0.001, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
 
     return TrainConfig(
@@ -164,11 +165,13 @@ def get_config():
         callbacks=Callbacks([
             StatPrinter(),
             ModelSaver(),
+            ScheduledHyperParamSetter('learning_rate',
+                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
             HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
             RunOp(lambda: M.update_target_param()),
             dataset_train,
-            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 3),
         ]),
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.6),
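
Note: the new ScheduledHyperParamSetter, combined with the higher initial value from the previous hunk (1e-3 instead of 4e-4), turns the learning rate into a step decay: 1e-3 until epoch 150, then 4e-4, 1e-4 from epoch 250, and 5e-5 from epoch 350, while HumanHyperParamSetter still allows manual overrides through hyper.txt; the Evaluator also now runs every 3 epochs instead of every 2. A sketch of the resulting piecewise-constant schedule (assuming each scheduled value takes effect once its epoch is reached):

    def lr_at(epoch, base=1e-3,
              schedule=((150, 4e-4), (250, 1e-4), (350, 5e-5))):
        # piecewise constant: keep the value of the last milestone passed
        lr = base
        for milestone, value in schedule:
            if epoch >= milestone:
                lr = value
        return lr

    lr_at(100)   # 0.001
    lr_at(200)   # 0.0004
    lr_at(400)   # 5e-05
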
@@ -205,6 +208,6 @@ if __name__ == '__main__':
         if args.load:
             config.session_init = SaverRestore(args.load)
         SimpleTrainer(config).train()
-        # TODO test if queue trainer works
         #QueueInputTrainer(config).train()
+        # TODO test if QueueInput affects learning