Commit 621a1bbd authored by Yuxin Wu

Fix wrong variable collection in distributed training (#431)

parent 7632ca9f
......@@ -33,6 +33,8 @@ def _get_cached_vs(name):
@contextmanager
def _enter_vs_reuse_ns(name):
    """Context manager that re-enters a previously cached variable scope.

    Looks up the variable scope object via ``_get_cached_vs(name)``, enters
    it, and also enters that scope's ``original_name_scope`` as the name
    scope for the duration of the ``with`` body.

    Args:
        name: key passed through to ``_get_cached_vs``.

    Yields:
        The variable scope object returned by ``_get_cached_vs(name)``.
    """
    cached_scope = _get_cached_vs(name)
    # XXX Entering the cached scope object directly is not ideal, because
    # doing so will clean up any custom getter.
    # Alternative (available in TF 1.4 only):
    #   with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope(cached_scope), \
            tf.name_scope(cached_scope.original_name_scope):
        yield cached_scope
......
......@@ -13,6 +13,7 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable
__all__ = ['DistributedTrainerReplicated']
......@@ -180,6 +181,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
opt = self.model.get_optimizer()
var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads):
# TODO do we put momentum variables into local or global?
apply_gradient_op = opt.apply_gradients([(g, v)])
barrier = self._add_sync_queues_and_barrier(
'param_update_barrier_{}'.format(vid), [apply_gradient_op])
......@@ -201,6 +203,9 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
gs = get_global_step_var()
assert gs.device, gs.device
# do this before inputsource.setup because input_source may need global step
with override_to_local_variable():
# input source may create variable (queue size summary)
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
......@@ -258,6 +263,11 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
local_init_op=local_init_op,
ready_op=ready_op, graph=tf.get_default_graph())
# to debug wrong variable collection
# print("GLOBAL:")
# print(tf.global_variables())
# print("LOCAL:")
# print(tf.local_variables())
def _create_session():
if self.is_chief:
return sm.prepare_session(master=self.server.target, init_op=init_op)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment