Commit b61a2d89 authored by Yuxin Wu

improve dqn

parent 77755875
...@@ -34,10 +34,6 @@ ACTION_REPEAT = 4 ...@@ -34,10 +34,6 @@ ACTION_REPEAT = 4
GAMMA = 0.99 GAMMA = 0.99
INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.01
END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6 MEMORY_SIZE = 1e6
# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory. # NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
INIT_MEMORY_SIZE = 5e4 INIT_MEMORY_SIZE = 5e4
...@@ -73,18 +69,10 @@ class Model(DQNModel): ...@@ -73,18 +69,10 @@ class Model(DQNModel):
with argscope(Conv2D, nl=PReLU.symbolic_function, use_bias=True), \ with argscope(Conv2D, nl=PReLU.symbolic_function, use_bias=True), \
argscope(LeakyReLU, alpha=0.01): argscope(LeakyReLU, alpha=0.01):
l = (LinearWrap(image) l = (LinearWrap(image)
.Conv2D('conv0', out_channel=32, kernel_shape=5)
.MaxPooling('pool0', 2)
.Conv2D('conv1', out_channel=32, kernel_shape=5)
.MaxPooling('pool1', 2)
.Conv2D('conv2', out_channel=64, kernel_shape=4)
.MaxPooling('pool2', 2)
.Conv2D('conv3', out_channel=64, kernel_shape=3)
# the original arch is 2x faster # the original arch is 2x faster
# .Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4) .Conv2D('conv0', out_channel=32, kernel_shape=8, stride=4)
# .Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2) .Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
# .Conv2D('conv2', out_channel=64, kernel_shape=3) .Conv2D('conv2', out_channel=64, kernel_shape=3)
.FullyConnected('fc0', 512, nl=LeakyReLU)()) .FullyConnected('fc0', 512, nl=LeakyReLU)())
if self.method != 'Dueling': if self.method != 'Dueling':
...@@ -108,9 +96,7 @@ def get_config(): ...@@ -108,9 +96,7 @@ def get_config():
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE, memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE, init_memory_size=INIT_MEMORY_SIZE,
exploration=INIT_EXPLORATION, init_exploration=1.0,
end_exploration=END_EXPLORATION,
exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
update_frequency=4, update_frequency=4,
history_len=FRAME_HISTORY history_len=FRAME_HISTORY
) )
...@@ -121,6 +107,10 @@ def get_config(): ...@@ -121,6 +107,10 @@ def get_config():
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]), [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
ScheduledHyperParamSetter(
ObjAttrParam(expreplay, 'exploration'),
[(0, 1), (100, 0.1), (200, 0.01)],
interp='linear'),
RunOp(DQNModel.update_target_param), RunOp(DQNModel.update_target_param),
expreplay, expreplay,
PeriodicTrigger(Evaluator( PeriodicTrigger(Evaluator(
......
...@@ -123,7 +123,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -123,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
state_shape, state_shape,
batch_size, batch_size,
memory_size, init_memory_size, memory_size, init_memory_size,
exploration, end_exploration, exploration_epoch_anneal, init_exploration,
update_frequency, history_len): update_frequency, history_len):
""" """
Args: Args:
...@@ -140,6 +140,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -140,6 +140,7 @@ class ExpReplay(DataFlow, Callback):
for k, v in locals().items(): for k, v in locals().items():
if k != 'self': if k != 'self':
setattr(self, k, v) setattr(self, k, v)
self.exploration = init_exploration
self.num_actions = player.get_action_space().num_actions() self.num_actions = player.get_action_space().num_actions()
logger.info("Number of Legal actions: {}".format(self.num_actions)) logger.info("Number of Legal actions: {}".format(self.num_actions))
...@@ -245,9 +246,6 @@ class ExpReplay(DataFlow, Callback): ...@@ -245,9 +246,6 @@ class ExpReplay(DataFlow, Callback):
self._simulator_th.start() self._simulator_th.start()
def _trigger_epoch(self): def _trigger_epoch(self):
if self.exploration > self.end_exploration:
self.exploration -= self.exploration_epoch_anneal
logger.info("Exploration changed to {}".format(self.exploration))
# log player statistics # log player statistics
stats = self.player.stats stats = self.player.stats
for k, v in six.iteritems(stats): for k, v in six.iteritems(stats):
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment