Commit 6af22cdd authored by Yuxin Wu's avatar Yuxin Wu

bug fix in super

parent d9e7c6bf
...@@ -111,9 +111,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -111,9 +111,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
self._average_gradient = average_gradient self._average_gradient = average_gradient
def _setup(self): def _setup(self):
super(SyncMultiGPUTrainer, self)._setup() super(AsyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads( grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1]) self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor() gradprocs = self.model.get_gradient_processor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment