Commit c8028236 authored by Yuxin Wu's avatar Yuxin Wu

fix local variable override bug (fix #430)

parent 8648d571
......@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedTrainerReplicated']
......@@ -205,15 +204,14 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
with override_to_local_variable():
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
devices=self.raw_devices,
use_vs=[True] * self.config.nr_tower) # open vs at each tower
MultiGPUTrainerBase._check_grad_list(grad_list)
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
devices=self.raw_devices,
use_vs=[True] * self.config.nr_tower) # open vs at each tower
MultiGPUTrainerBase._check_grad_list(grad_list)
avg_grads = DistributedTrainerReplicated._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment