Commit c8028236 authored by Yuxin Wu's avatar Yuxin Wu

fix local variable override bug (fix #430)

parent 8648d571
...@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator ...@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedTrainerReplicated'] __all__ = ['DistributedTrainerReplicated']
...@@ -205,7 +204,6 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase): ...@@ -205,7 +204,6 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
cbs = self._input_source.setup(self.model.get_inputs_desc()) cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs) self.config.callbacks.extend(cbs)
with override_to_local_variable():
# Ngpu * Nvar * 2 # Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower( grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, self.config.tower,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment