Commit a8dfad63 authored by Yuxin Wu

match "steps"

parent b61a2d89
@@ -30,14 +30,15 @@ from expreplay import ExpReplay
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
-ACTION_REPEAT = 4
+ACTION_REPEAT = 4   # aka FRAME_SKIP
+UPDATE_FREQ = 4
 GAMMA = 0.99
 MEMORY_SIZE = 1e6
 # NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 INIT_MEMORY_SIZE = 5e4
-STEPS_PER_EPOCH = 10000
+STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10  # each epoch is 100k played frames
 EVAL_EPISODE = 50
 NUM_ACTIONS = None
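
The epoch length is now defined in gradient-update steps rather than raw iterations. A quick sanity check of the arithmetic (a minimal sketch using the constants above; the frame conversion is our reading of the "each epoch is 100k played frames" comment):

    UPDATE_FREQ = 4                               # one gradient update per 4 played frames
    STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10   # = 2500 * 10 = 25000 update steps

    # Each update step consumes UPDATE_FREQ new frames from the simulator,
    # so one epoch covers 25000 * 4 = 100000 played frames.
    assert STEPS_PER_EPOCH * UPDATE_FREQ == 100000
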
@@ -97,7 +98,7 @@ def get_config():
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
         init_exploration=1.0,
-        update_frequency=4,
+        update_frequency=UPDATE_FREQ,
         history_len=FRAME_HISTORY
     )
@@ -106,21 +107,24 @@ def get_config():
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
-                                      [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
+                                      [(60, 4e-4), (100, 2e-4)]),
             ScheduledHyperParamSetter(
                 ObjAttrParam(expreplay, 'exploration'),
-                [(0, 1), (100, 0.1), (200, 0.01)],
+                [(0, 1), (10, 0.1), (240, 0.01)],
                 interp='linear'),
-            RunOp(DQNModel.update_target_param),
+            PeriodicTrigger(
+                RunOp(DQNModel.update_target_param),
+                every_k_steps=10000 // UPDATE_FREQ),
             expreplay,
             PeriodicTrigger(Evaluator(
                 EVAL_EPISODE, ['state'], ['Qvalue'], get_player),
-                every_k_epochs=5),
-            # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
+                every_k_epochs=10),
+            HumanHyperParamSetter('learning_rate'),
             # HumanHyperParamSetter(ObjAttrParam(expreplay, 'exploration'), 'hyper.txt'),
         ],
         model=M,
         steps_per_epoch=STEPS_PER_EPOCH,
+        max_epoch=3000,
         # run the simulator on a separate GPU if available
         predict_tower=[1] if get_nr_gpu() > 1 else [0],
     )
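
With the target-network update wrapped in a PeriodicTrigger, its interval is now also expressed in training steps, which is presumably what the commit title, match "steps", refers to. A rough check of what the trigger interval means in played frames (constant names from the diff; the conversion is our interpretation):

    UPDATE_FREQ = 4
    sync_every_steps = 10000 // UPDATE_FREQ    # 2500 gradient updates between syncs

    # 2500 update steps * 4 frames per step = 10000 played frames
    # between target-network refreshes.
    assert sync_every_steps * UPDATE_FREQ == 10000
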
...
@@ -147,7 +147,6 @@ class ExpReplay(DataFlow, Callback):
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
-        # TODO just use a semaphore?
         # a queue to receive notifications to populate memory
         self._populate_job_queue = queue.Queue(maxsize=5)
@@ -246,15 +245,15 @@ class ExpReplay(DataFlow, Callback):
         self._simulator_th.start()

     def _trigger_epoch(self):
-        # log player statistics
+        # log player statistics in training
         stats = self.player.stats
         for k, v in six.iteritems(stats):
             try:
                 mean, max = np.mean(v), np.max(v)
-                self.trainer.add_scalar_summary('expreplay/mean_' + k, mean)
-                self.trainer.add_scalar_summary('expreplay/max_' + k, max)
+                self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
+                self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
             except:
-                pass
+                logger.exception("Cannot log training scores.")
         self.player.reset_stat()
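
The logging call moves from the old add_scalar_summary helper to the trainer's monitors backend, and failures are now reported via logger.exception instead of being silently swallowed. A minimal standalone sketch of the same pattern (the callback class and stat source are hypothetical; put_scalar and logger are the tensorpack APIs used in the diff):

    import numpy as np
    from tensorpack.callbacks import Callback
    from tensorpack.utils import logger

    class PlayerStatLogger(Callback):
        """Hypothetical callback pushing per-epoch statistics to all monitors."""
        def __init__(self, get_stats):
            self._get_stats = get_stats   # callable returning {name: list of values}

        def _trigger_epoch(self):
            for k, v in self._get_stats().items():
                try:
                    # put_scalar forwards the value to every registered monitor
                    # (TensorBoard events, JSON stats, stdout, ...).
                    self.trainer.monitors.put_scalar('stats/mean_' + k, np.mean(v))
                except Exception:
                    logger.exception("Cannot log training scores.")
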
...
@@ -27,7 +27,7 @@ class RunOp(Callback):
         Examples:
             The `DQN Example
-            <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
+            <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/DeepQNetwork/>`_
             uses this callback to update target network.
         """
         self.setup_func = setup_func
...