Commit 621a1bbd authored by Yuxin Wu

Fix wrong variable collection in distributed training (#431)

parent 7632ca9f
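For context on the fix, a hedged, generic TF 1.x illustration (not tensorpack code) of the two variable collections involved: variables created by `tf.get_variable` land in GLOBAL_VARIABLES by default, while variables placed in LOCAL_VARIABLES are initialized per process by the `local_init_op`. The change below routes variables created by the input source (e.g. the queue size summary) into the local collection.

```python
# Hedged background sketch (plain TF 1.x, not tensorpack code): where variables
# end up depending on the `collections` argument.
import tensorflow as tf

g = tf.get_variable('g', shape=[])          # defaults to GLOBAL_VARIABLES
l = tf.get_variable('l', shape=[], trainable=False,
                    collections=[tf.GraphKeys.LOCAL_VARIABLES])

print(tf.global_variables())   # [<tf.Variable 'g:0' ...>]
print(tf.local_variables())    # [<tf.Variable 'l:0' ...>]
```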
@@ -33,6 +33,8 @@ def _get_cached_vs(name):
 @contextmanager
 def _enter_vs_reuse_ns(name):
     vs = _get_cached_vs(name)
+    # XXX Not good to enter the cached vs directly, because this will clean-up custom getter
+    # with tf.variable_scope(name, reuse=tf.AUTO_REUSE):   # available in 1.4 only
     with tf.variable_scope(vs):
         with tf.name_scope(vs.original_name_scope):
             yield vs
......
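The commented-out alternative mentioned above, `tf.AUTO_REUSE`, is only available from TF 1.4 on. A minimal sketch of how it behaves (illustrative, not tensorpack code):

```python
# Hedged sketch of the tf.AUTO_REUSE alternative mentioned above (requires TF >= 1.4):
# the scope creates a variable on first use and reuses it on later calls, without
# having to hold on to a cached VariableScope object.
import tensorflow as tf

def shared_dense(x):
    with tf.variable_scope('fc', reuse=tf.AUTO_REUSE):
        w = tf.get_variable('w', shape=[4, 10])
        return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 4])
y1 = shared_dense(x)   # creates fc/w
y2 = shared_dense(x)   # reuses fc/w, no "variable already exists" error
```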
@@ -13,6 +13,7 @@ from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.common import get_global_step_var, get_op_tensor_name
 from .multigpu import MultiGPUTrainerBase
+from .utility import override_to_local_variable
 __all__ = ['DistributedTrainerReplicated']
@@ -180,6 +181,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         opt = self.model.get_optimizer()
         var_update_ops = []
         for vid, (g, v) in enumerate(ps_var_grads):
+            # TODO do we put momentum variables into local or global?
             apply_gradient_op = opt.apply_gradients([(g, v)])
             barrier = self._add_sync_queues_and_barrier(
                 'param_update_barrier_{}'.format(vid), [apply_gradient_op])
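Regarding the TODO above, a hedged sketch of the default behavior (plain TF 1.x, not tensorpack code): optimizer slot variables such as the momentum accumulator are created as ordinary global variables alongside the variable they track.

```python
# Hedged sketch: by default, apply_gradients creates its slot variables (here the
# momentum accumulator) as global variables, not local ones.
import tensorflow as tf

v = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
g = tf.ones_like(v)                      # stand-in gradient
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
train_op = opt.apply_gradients([(g, v)])

print([x.name for x in tf.global_variables()])  # 'w:0' plus a slot like 'w/Momentum:0'
print([x.name for x in tf.local_variables()])   # the slot does not show up here
```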
@@ -201,7 +203,10 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         gs = get_global_step_var()
         assert gs.device, gs.device
         # do this before inputsource.setup because input_source my need global step
-        cbs = self._input_source.setup(self.model.get_inputs_desc())
+        with override_to_local_variable():
+            # input source may create variable (queue size summary)
+            cbs = self._input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)
 
         # build the optimizer first, before entering any tower
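A hedged sketch of what a helper like `override_to_local_variable` could look like: a context manager that installs a custom getter re-routing newly created variables into LOCAL_VARIABLES. This illustrates the idea only; it is not tensorpack's actual implementation.

```python
# Hedged sketch, not tensorpack's implementation: re-enter the current variable
# scope with a custom getter that swaps GLOBAL_VARIABLES for LOCAL_VARIABLES.
from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def override_to_local_variable():
    def custom_getter(getter, name, *args, **kwargs):
        collections = list(kwargs.pop('collections', None) or
                           [tf.GraphKeys.GLOBAL_VARIABLES])
        if tf.GraphKeys.GLOBAL_VARIABLES in collections:
            collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.append(tf.GraphKeys.LOCAL_VARIABLES)
        return getter(name, *args, collections=collections, **kwargs)

    with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
        yield
```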
@@ -258,6 +263,11 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
             local_init_op=local_init_op,
             ready_op=ready_op, graph=tf.get_default_graph())
+        # to debug wrong variable collection
+        # print("GLOBAL:")
+        # print(tf.global_variables())
+        # print("LOCAL:")
+        # print(tf.local_variables())
 
         def _create_session():
             if self.is_chief:
                 return sm.prepare_session(master=self.server.target, init_op=init_op)
......
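For reference, a hedged sketch of the generic `tf.train.SessionManager` pattern the hunk above builds on (standard TF 1.x usage, not tensorpack's exact arguments): global variables are initialized once by the chief, local variables are initialized in every process by `local_init_op`, and `ready_op` only reports on the global collection.

```python
# Hedged sketch of generic tf.train.SessionManager usage (TF 1.x), not the exact
# tensorpack arguments. `server` below is a hypothetical tf.train.Server.
import tensorflow as tf

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
ready_op = tf.report_uninitialized_variables(tf.global_variables())

sm = tf.train.SessionManager(local_init_op=local_init_op,
                             ready_op=ready_op,
                             graph=tf.get_default_graph())

# Chief worker: run init_op, then serve the session.
# sess = sm.prepare_session(master=server.target, init_op=init_op)
# Non-chief workers typically wait until the chief has initialized everything:
# sess = sm.wait_for_session(master=server.target)
```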