Commit c8028236 authored by Yuxin Wu's avatar Yuxin Wu

fix local variable override bug (fix #430)

parent 8648d571
...@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator ...@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedTrainerReplicated'] __all__ = ['DistributedTrainerReplicated']
...@@ -205,15 +204,14 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase): ...@@ -205,15 +204,14 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
cbs = self._input_source.setup(self.model.get_inputs_desc()) cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs) self.config.callbacks.extend(cbs)
with override_to_local_variable(): # Ngpu * Nvar * 2
# Ngpu * Nvar * 2 grad_list = MultiGPUTrainerBase.build_on_multi_tower(
grad_list = MultiGPUTrainerBase.build_on_multi_tower( self.config.tower,
self.config.tower, lambda: MultiGPUTrainerBase._build_graph_get_grads(
lambda: MultiGPUTrainerBase._build_graph_get_grads( self.model, self._input_source),
self.model, self._input_source), devices=self.raw_devices,
devices=self.raw_devices, use_vs=[True] * self.config.nr_tower) # open vs at each tower
use_vs=[True] * self.config.nr_tower) # open vs at each tower MultiGPUTrainerBase._check_grad_list(grad_list)
MultiGPUTrainerBase._check_grad_list(grad_list)
avg_grads = DistributedTrainerReplicated._average_grads(grad_list, self.raw_devices) avg_grads = DistributedTrainerReplicated._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device): with tf.device(self.param_server_device):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment