Commit 399db3ae authored by Yuxin Wu

Merge branch 'dqn-improve'

parents 77755875 1f94ae78
```diff
@@ -30,18 +30,15 @@ from expreplay import ExpReplay
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
-ACTION_REPEAT = 4
+ACTION_REPEAT = 4   # aka FRAME_SKIP
+UPDATE_FREQ = 4
 GAMMA = 0.99
-INIT_EXPLORATION = 1
-EXPLORATION_EPOCH_ANNEAL = 0.01
-END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
-# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
+# will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 INIT_MEMORY_SIZE = 5e4
-STEPS_PER_EPOCH = 10000
+STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10  # each epoch is 100k played frames
 EVAL_EPISODE = 50
 NUM_ACTIONS = None
```
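For reference, here is how the new constants relate to the "100k played frames" comment (a quick arithmetic sketch, not part of the commit):

```python
# Sketch only: values copied from the diff above.
UPDATE_FREQ = 4                                # one gradient update every 4 agent steps
ACTION_REPEAT = 4                              # each agent step repeats the action for 4 emulator frames
STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10    # 25000 gradient steps per epoch

frames_per_epoch = STEPS_PER_EPOCH * UPDATE_FREQ     # agent steps ("played frames") per epoch
assert frames_per_epoch == 100000                    # matches the "100k played frames" comment
assert frames_per_epoch * ACTION_REPEAT == 400000    # raw emulator frames per epoch (with frame skip)
```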
```diff
@@ -73,18 +70,19 @@ class Model(DQNModel):
         with argscope(Conv2D, nl=PReLU.symbolic_function, use_bias=True), \
                 argscope(LeakyReLU, alpha=0.01):
             l = (LinearWrap(image)
-                 .Conv2D('conv0', out_channel=32, kernel_shape=5)
-                 .MaxPooling('pool0', 2)
-                 .Conv2D('conv1', out_channel=32, kernel_shape=5)
-                 .MaxPooling('pool1', 2)
-                 .Conv2D('conv2', out_channel=64, kernel_shape=4)
-                 .MaxPooling('pool2', 2)
-                 .Conv2D('conv3', out_channel=64, kernel_shape=3)
-
-                 # the original arch is 2x faster
-                 # .Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4)
-                 # .Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
-                 # .Conv2D('conv2', out_channel=64, kernel_shape=3)
+                 # Nature architecture
+                 .Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4)
+                 .Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
+                 .Conv2D('conv2', out_channel=64, kernel_shape=3)
+
+                 # architecture used for the figure in the README, slower but takes fewer iterations to converge
+                 # .Conv2D('conv0', out_channel=32, kernel_shape=5)
+                 # .MaxPooling('pool0', 2)
+                 # .Conv2D('conv1', out_channel=32, kernel_shape=5)
+                 # .MaxPooling('pool1', 2)
+                 # .Conv2D('conv2', out_channel=64, kernel_shape=4)
+                 # .MaxPooling('pool2', 2)
+                 # .Conv2D('conv3', out_channel=64, kernel_shape=3)
                  .FullyConnected('fc0', 512, nl=LeakyReLU)())
         if self.method != 'Dueling':
```
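For context on the "Nature architecture" comment: with the valid-padding convolutions of the Nature paper, the feature maps shrink 84 → 20 → 9 → 7. The sketch below works through that arithmetic; it reflects the paper's convolution arithmetic, not necessarily the exact shapes produced by tensorpack's default padding.

```python
# Feature-map sizes for the Nature DQN convolutions (valid padding, as in the paper).
def conv_out(size, kernel, stride):
    return (size - kernel) // stride + 1

s = 84
s = conv_out(s, 8, 4)   # conv0 -> 20
s = conv_out(s, 4, 2)   # conv1 -> 9
s = conv_out(s, 3, 1)   # conv2 -> 7
assert s == 7           # 7 * 7 * 64 = 3136 features feeding the 512-unit FC layer
```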
```diff
@@ -98,8 +96,6 @@ class Model(DQNModel):

 def get_config():
-    logger.auto_set_dir()
     M = Model()
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
@@ -108,10 +104,8 @@ def get_config():
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
-        exploration=INIT_EXPLORATION,
-        end_exploration=END_EXPLORATION,
-        exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
-        update_frequency=4,
+        init_exploration=1.0,
+        update_frequency=UPDATE_FREQ,
         history_len=FRAME_HISTORY
     )
```
```diff
@@ -119,18 +113,24 @@ def get_config():
         dataflow=expreplay,
         callbacks=[
             ModelSaver(),
-            ScheduledHyperParamSetter('learning_rate',
-                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
-            RunOp(DQNModel.update_target_param),
+            PeriodicTrigger(
+                RunOp(DQNModel.update_target_param),
+                every_k_steps=10000 // UPDATE_FREQ),    # update target network every 10k steps
             expreplay,
+            ScheduledHyperParamSetter('learning_rate',
+                                      [(60, 4e-4), (100, 2e-4)]),
+            ScheduledHyperParamSetter(
+                ObjAttrParam(expreplay, 'exploration'),
+                [(0, 1), (10, 0.1), (320, 0.01)],   # 1->0.1 in the first million steps
+                interp='linear'),
             PeriodicTrigger(Evaluator(
                 EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
-                every_k_epochs=5),
-            # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
-            # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
+                every_k_epochs=10),
+            HumanHyperParamSetter('learning_rate'),
         ],
         model=M,
         steps_per_epoch=STEPS_PER_EPOCH,
+        max_epoch=1000,
         # run the simulator on a separate GPU if available
         predict_tower=[1] if get_nr_gpu() > 1 else [0],
     )
```
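The new callback settings work out as follows (a sketch using the constants defined earlier in this diff, not part of the commit):

```python
# Sketch only: what the new callback schedule amounts to.
UPDATE_FREQ = 4
STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10            # 25000 gradient steps per epoch

# Target network sync: every_k_steps=10000 // UPDATE_FREQ = 2500 gradient steps,
# i.e. 2500 * UPDATE_FREQ = 10000 played frames between syncs.
assert (10000 // UPDATE_FREQ) * UPDATE_FREQ == 10000

# Exploration schedule [(0, 1), (10, 0.1), (320, 0.01)] with interp='linear':
# epsilon falls from 1 to 0.1 over the first 10 epochs,
# 10 * 25000 * UPDATE_FREQ played frames = the "first million steps" in the comment.
assert 10 * STEPS_PER_EPOCH * UPDATE_FREQ == 1000000
```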
```diff
@@ -170,6 +170,9 @@ if __name__ == '__main__':
     elif args.task == 'eval':
         eval_model_multithread(cfg, EVAL_EPISODE, get_player)
     else:
+        logger.set_logger_dir(
+            'train_log/DQN-{}'.format(
+                os.path.basename(ROM_FILE).split('.')[0]))
         config = get_config()
         if args.load:
             config.session_init = SaverRestore(args.load)
```
......
```diff
@@ -19,13 +19,10 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 ![DQN](curve-breakout.png)

-DQN typically took 1 day of training to reach a score of 400 on breakout game (same as the paper).
-My Batch-A3C implementation only took <2 hours.
+On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game.
+Batch-A3C implementation only took <2 hours. (Both are trained with a larger network noted in the code).
+Both were trained on one GPU with an extra GPU for simulation.

-Double-DQN runs at 18 batches/s (1152 frames/s) on TitanX.
-Note that I wasn't using the network architecture in the paper.
-If switched to the network in the paper it could run 2x faster.
+Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on TitanX.

 ## How to use
```
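The throughput figures in the updated README follow directly from the constants in DQN.py (a quick check, not part of the commit):

```python
# Sketch only: how the README throughput numbers relate, using constants from DQN.py above.
BATCH_SIZE = 64
UPDATE_FREQ = 4       # one gradient update every 4 new frames
ACTION_REPEAT = 4     # frame skip
batches_per_sec = 60

trained_frames = batches_per_sec * BATCH_SIZE   # frames sampled from replay for training each second
seen_frames = batches_per_sec * UPDATE_FREQ     # new frames the agent observes each second
game_frames = seen_frames * ACTION_REPEAT       # raw emulator frames each second
assert (trained_frames, seen_frames, game_frames) == (3840, 240, 960)
```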
......
```diff
@@ -123,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
                  state_shape,
                  batch_size,
                  memory_size, init_memory_size,
-                 exploration, end_exploration, exploration_epoch_anneal,
+                 init_exploration,
                  update_frequency, history_len):
         """
         Args:
@@ -140,13 +140,13 @@ class ExpReplay(DataFlow, Callback):
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
+        self.exploration = init_exploration
         self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
-        # TODO just use a semaphore?

         # a queue to receive notifications to populate memory
         self._populate_job_queue = queue.Queue(maxsize=5)
@@ -245,18 +245,15 @@ class ExpReplay(DataFlow, Callback):
         self._simulator_th.start()

     def _trigger_epoch(self):
-        if self.exploration > self.end_exploration:
-            self.exploration -= self.exploration_epoch_anneal
-            logger.info("Exploration changed to {}".format(self.exploration))
-        # log player statistics
+        # log player statistics in training
         stats = self.player.stats
         for k, v in six.iteritems(stats):
             try:
                 mean, max = np.mean(v), np.max(v)
-                self.trainer.add_scalar_summary('expreplay/mean_' + k, mean)
-                self.trainer.add_scalar_summary('expreplay/max_' + k, max)
+                self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
+                self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
             except:
-                pass
+                logger.exception("Cannot log training scores.")
         self.player.reset_stat()
```
......
```diff
@@ -27,7 +27,7 @@ class RunOp(Callback):
         Examples:
             The `DQN Example
-            <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
+            <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/DeepQNetwork/>`_
             uses this callback to update target network.
         """
         self.setup_func = setup_func
```
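For illustration, the DQN example referenced here wraps `RunOp` in a `PeriodicTrigger`, as added in the DQN.py diff above (a usage sketch; all names are taken from this commit):

```python
# Usage sketch, mirroring the callback added in DQN.py in this commit.
PeriodicTrigger(
    RunOp(DQNModel.update_target_param),    # op that copies online-network weights to the target network
    every_k_steps=10000 // UPDATE_FREQ)     # i.e. run the op every 2500 gradient steps
```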
......
```diff
@@ -217,8 +217,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
             param: same as in :class:`HyperParamSetter`.
             schedule (list): with the format ``[(epoch1, val1), (epoch2, val2), (epoch3, val3)]``.
                 Each ``(ep, val)`` pair means to set the param
-                to "val" __after__ the completion of `ep` th epoch.
-                If ep == 0, the value will be set before the first epoch.
+                to "val" __after__ the completion of epoch `ep`.
+                If ep == 0, the value will be set before the first epoch
+                (by default the first is epoch 1).
             interp: None: no interpolation. 'linear': linear interpolation

         Example:
```
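A minimal usage sketch of the documented schedule semantics, mirroring the learning-rate and exploration schedules added in DQN.py in this commit (the import path is assumed, and `expreplay` is a stand-in for the ExpReplay instance built there):

```python
from tensorpack.callbacks import ScheduledHyperParamSetter, ObjAttrParam

# Step schedule: set 'learning_rate' to 4e-4 after epoch 60 and to 2e-4 after epoch 100
# (values taken from the DQN.py diff in this commit).
lr_setter = ScheduledHyperParamSetter('learning_rate', [(60, 4e-4), (100, 2e-4)])

# Linearly interpolated schedule on an object attribute: anneal `expreplay.exploration`
# from 1 (set before epoch 1, since ep == 0) to 0.1 at epoch 10, then to 0.01 at epoch 320.
eps_setter = ScheduledHyperParamSetter(
    ObjAttrParam(expreplay, 'exploration'),
    [(0, 1), (10, 0.1), (320, 0.01)],
    interp='linear')
```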
......