Commit c8028236 authored by Yuxin Wu's avatar Yuxin Wu

fix local variable override bug (fix #430)

parent 8648d571
......@@ -13,7 +13,6 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedTrainerReplicated']
......@@ -205,7 +204,6 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
with override_to_local_variable():
# Ngpu * Nvar * 2
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment