Commit 319e7591 authored by Yuxin Wu's avatar Yuxin Wu

add test for PTB; fix A3C.

parent fa025551
......@@ -124,6 +124,10 @@ def get_config():
lambda: ptb_producer(data3[1], BATCH, SEQ_LEN),
(data3[1].shape[0] // BATCH - 1) // SEQ_LEN)
test_data = TensorInput(
lambda: ptb_producer(data3[2], BATCH, SEQ_LEN),
(data3[2].shape[0] // BATCH - 1) // SEQ_LEN)
M = Model()
return TrainConfig(
data=train_data,
......@@ -135,12 +139,20 @@ def get_config():
lambda e, x: x * 0.80 if e > 6 else x),
RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]),
RunOp(lambda: M.reset_lstm_state()),
FeedfreeInferenceRunner(
test_data,
[ScalarStats(['cost'], prefix='test')], prefix='test'),
RunOp(lambda: M.reset_lstm_state()),
CallbackFactory(
trigger_epoch=lambda self:
self.trainer.add_scalar_summary(
[self.trainer.add_scalar_summary(
'validation_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))),
RunOp(lambda: M.reset_lstm_state()),
np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN)),
self.trainer.add_scalar_summary(
'test_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('test_cost') / SEQ_LEN))]
),
],
max_epoch=70,
)
......
......@@ -190,12 +190,12 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
super(AsyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
grad_list = FilterNoneGrad().process(grad_list)
grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
if self._scale_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
grad_list = gradproc.process(grad_list)
grad_list = [gradproc.process(gv) for gv in grad_list]
# use grad from the first tower for iteration in main thread
self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment