Commit ba435f10 authored by Yuxin Wu

update DQN hyperparam

parent 4071cbec
......@@ -20,11 +20,12 @@ from tensorpack.RL import *
import common
from common import play_model, Evaluator, eval_model_multithread
BATCH_SIZE = 32
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204) # for breakout
HEIGHT_RANGE = (None, None)
#HEIGHT_RANGE = (36, 204) # for breakout
#HEIGHT_RANGE = (28, -8) # for pong
CHANNEL = FRAME_HISTORY
......@@ -32,7 +33,7 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
GAMMA = 0.99
INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.008
EXPLORATION_EPOCH_ANNEAL = 0.01
END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6
......@@ -133,7 +134,7 @@ class Model(ModelDesc):
SummaryGradient()]
def predictor(self, state):
# TODO change to a multitower predictor for speedup
# TODO use multitower predictor to speed up training
return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
def get_config():
......@@ -155,7 +156,7 @@ def get_config():
reward_clip=(-1, 1),
history_len=FRAME_HISTORY)
lr = tf.Variable(0.0004, trainable=False, name='learning_rate')
lr = tf.Variable(0.001, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......@@ -164,11 +165,13 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ScheduledHyperParamSetter('learning_rate',
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
HumanHyperParamSetter('learning_rate', 'hyper.txt'),
HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
RunOp(lambda: M.update_target_param()),
dataset_train,
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 2),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 3),
]),
# save memory for multiprocess evaluator
session_config=get_default_sess_config(0.6),
......@@ -205,6 +208,6 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
# TODO test if queue trainer works
#QueueInputTrainer(config).train()
# TODO test if QueueInput affects learning
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment