Commit 12846f57 authored by Yuxin Wu

Maintain global_step in each job as a local_variable

parent 3498c0b6
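This commit replaces the old pattern of pinning global_step to the parameter-server device with a per-job local variable: each job now creates global_step inside override_to_local_variable(), so the step counter lands in that job's LOCAL_VARIABLES collection instead of being shared through the PS. The helper's body is not part of this diff; below is a minimal sketch, assuming it is a variable_scope custom getter that swaps GLOBAL_VARIABLES for LOCAL_VARIABLES (the real code in graph_builder/utils.py may differ in detail).

# Minimal sketch of override_to_local_variable, assuming a custom-getter
# implementation; the actual helper in graph_builder/utils.py may differ.
from contextlib import contextmanager

import tensorflow as tf


@contextmanager
def override_to_local_variable(enable=True):
    def custom_getter(getter, name, *args, **kwargs):
        # Move every variable created in this block from GLOBAL_VARIABLES
        # to LOCAL_VARIABLES, so it is neither shared via the PS nor saved
        # by default.
        collections = set(kwargs.pop('collections', None) or
                          [tf.GraphKeys.GLOBAL_VARIABLES])
        collections.discard(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        return getter(name, *args, collections=list(collections), **kwargs)

    if not enable:
        yield
    else:
        with tf.variable_scope(tf.get_variable_scope(),
                               custom_getter=custom_getter):
            yield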
@@ -7,9 +7,10 @@ import re
 from six.moves import zip, range
 from ..utils.argtools import memoized
-from ..tfutils.common import get_global_step_var, get_op_tensor_name
+from ..tfutils.common import get_op_tensor_name, get_global_step_var
 from .training import DataParallelBuilder
+from .utils import override_to_local_variable

 __all__ = ['DistributedReplicatedBuilder']
@@ -178,10 +179,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
             tf.Operation: the op which sync all the local `MODEL_VARIABLES` from PS.
                 You can choose how often to run it by yourself.
         """
-        # do this before everything, because they may need global step
-        with tf.device(self.param_server_device):
-            gs = get_global_step_var()
-            assert gs.device, gs.device
+        with override_to_local_variable():
+            get_global_step_var()

         get_opt_fn = memoized(get_opt_fn)
         # Build the optimizer first, before entering any tower.
...
@@ -55,21 +55,18 @@ def get_default_sess_config(mem_fraction=0.99):
 def get_global_step_var():
     """
     Returns:
-        tf.Tensor: the global_step variable in the current graph. create if
+        tf.Tensor: the global_step variable in the current graph. Create if
             doesn't exist.
     """
-    scope = tf.get_variable_scope()
-    assert scope.name == '', \
-        "The global_step variable should be created under the root variable scope!"
-    assert not scope.reuse, \
-        "The global_step variable shouldn't be called under a reuse variable scope!"
-    if get_tf_version_number() <= 1.0:
-        var = tf.get_variable('global_step',
-                              initializer=tf.constant(0, dtype=tf.int64),
-                              trainable=False, dtype=tf.int64)
-        tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
-    else:
-        var = tf.train.get_or_create_global_step()
+    scope = tf.VariableScope(reuse=False, name='')  # the root vs
+    with tf.variable_scope(scope):
+        if get_tf_version_number() <= 1.0:
+            var = tf.get_variable('global_step',
+                                  initializer=tf.constant(0, dtype=tf.int64),
+                                  trainable=False, dtype=tf.int64)
+            tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
+        else:
+            var = tf.train.get_or_create_global_step()
     return var
...
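With this rewrite, get_global_step_var() no longer asserts that the caller is at the root, non-reuse variable scope; it jumps to the root scope itself, so it can be called from inside override_to_local_variable() or any other scope. The intended effect, sketched below as a hedged check (import paths taken from this diff, the rest illustrative):

# Hedged check of the intended effect: calling get_global_step_var() under
# override_to_local_variable() should register the step counter as a local
# variable of this job, not a global one shared via the parameter server.
import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var
from tensorpack.graph_builder.utils import override_to_local_variable

with override_to_local_variable():
    gs = get_global_step_var()

print(any(v.name == gs.name for v in tf.local_variables()))    # expected: True
print(any(v.name == gs.name for v in tf.global_variables()))   # expected: False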
@@ -35,6 +35,12 @@ class MaintainStepCounter(Callback):
     It maintains the global step in the graph, making sure it's increased by one.
     This callback is always enabled by the trainer, you don't need to worry about it.
     """
+
+    chief_only = False
+    """
+    In distributed training, we let each worker maintain its local global_step.
+    """
+
     def _setup_graph(self):
         # ensure it exists
         gs_var = get_global_step_var()
...
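MaintainStepCounter is now explicitly marked chief_only = False, so in distributed mode every worker keeps this callback and increments its own (now local) global_step, rather than only the chief doing so. The trainer's filtering logic is not shown in this diff; a plausible sketch of how such a flag is typically consumed:

# Illustrative sketch (not code from this commit): a trainer that honours
# `chief_only` would drop chief-only callbacks on non-chief workers, while
# callbacks like MaintainStepCounter (chief_only = False) run everywhere.
def filter_callbacks(callbacks, is_chief):
    kept = []
    for cb in callbacks:
        if getattr(cb, 'chief_only', True) and not is_chief:
            continue   # e.g. checkpoint saving stays on the chief
        kept.append(cb)
    return kept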
@@ -8,7 +8,7 @@ import os
 from ..utils import logger
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
-from ..tfutils.common import get_global_step_var
+from ..tfutils import get_global_step_var
 from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable
@@ -75,18 +75,14 @@ class DistributedTrainerReplicated(Trainer):
     def _setup(self):
         if self.job_name == 'ps':
-            logger.info("Running ps {}".format(self._builder.task_index))
+            logger.info("Running ps {}".format(self.server.server_def.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
             self.server.join()  # this will never return tensorflow#4713
             return

-        # always do this before inputsource.setup because input_source may need global step
-        # TODO Can we just do this in get_global_step_var
-        with tf.device(self._builder.param_server_device):
-            gs = get_global_step_var()
-            assert gs.device, gs.device
         with override_to_local_variable():
+            get_global_step_var()  # gs should be local
             # input source may create variable (queue size summary)
             # TODO This is not good because we don't know from here
             # whether something should be global or local. We now assume
...
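For context, the trainer's self.server is a tf.train.Server: the ps job only serves variables and blocks in server.join(), while worker jobs continue into the override_to_local_variable() block above to create their local step counter and input-source variables. A hedged sketch of the kind of server setup assumed here (host addresses, job name and task index are placeholders):

# Illustrative cluster/server setup of the kind DistributedTrainerReplicated
# expects; host addresses, job name and task index are placeholders.
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps':     ['ps0.example.com:2222'],
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
})
job_name, task_index = 'worker', 0            # set per process
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

if job_name == 'ps':
    server.join()   # parameter servers never return (see tensorflow#4713)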