Commit 12846f57 authored by Yuxin Wu's avatar Yuxin Wu

Maintain global_step in each job as a local_variable

parent 3498c0b6
......@@ -7,9 +7,10 @@ import re
from six.moves import zip, range
from ..utils.argtools import memoized
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .training import DataParallelBuilder
from .utils import override_to_local_variable
__all__ = ['DistributedReplicatedBuilder']
......@@ -178,10 +179,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
tf.Operation: the op which syncs all the local `MODEL_VARIABLES` from PS.
You can choose how often to run it yourself.
"""
# do this before everything, because they may need the global step
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
with override_to_local_variable():
get_global_step_var()
get_opt_fn = memoized(get_opt_fn)
# Build the optimizer first, before entering any tower.
......
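For readers unfamiliar with the helper: `override_to_local_variable()` (imported above from `.utils`) is what lets `get_global_step_var()` create the step counter as a worker-local variable instead of a PS-hosted one. Below is a minimal sketch of such a context manager, assuming a TF 1.x custom getter that redirects variables into `LOCAL_VARIABLES`; it is illustrative, not necessarily tensorpack's exact implementation.

```python
import tensorflow as tf
from contextlib import contextmanager

@contextmanager
def override_to_local_variable_sketch(enable=True):
    """Route tf.get_variable() calls inside the block into LOCAL_VARIABLES,
    so each job keeps its own copy instead of sharing one through the PS."""
    if not enable:
        yield
        return

    def custom_getter(getter, name, *args, **kwargs):
        # illustrative: force the variable into the local collection
        kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
        return getter(name, *args, **kwargs)

    # re-enter the current scope with the custom getter attached
    with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
        yield
```

With something like this in place, the `get_global_step_var()` call above yields a variable that is initialized per worker and never synchronized through the parameter servers.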
......@@ -55,21 +55,18 @@ def get_default_sess_config(mem_fraction=0.99):
def get_global_step_var():
"""
Returns:
tf.Tensor: the global_step variable in the current graph. create if
doesn't exist.
tf.Tensor: the global_step variable in the current graph. Create if
doesn't exist.
"""
scope = tf.get_variable_scope()
assert scope.name == '', \
"The global_step variable should be created under the root variable scope!"
assert not scope.reuse, \
"The global_step variable shouldn't be called under a reuse variable scope!"
if get_tf_version_number() <= 1.0:
var = tf.get_variable('global_step',
initializer=tf.constant(0, dtype=tf.int64),
trainable=False, dtype=tf.int64)
tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
else:
var = tf.train.get_or_create_global_step()
scope = tf.VariableScope(reuse=False, name='') # the root vs
with tf.variable_scope(scope):
if get_tf_version_number() <= 1.0:
var = tf.get_variable('global_step',
initializer=tf.constant(0, dtype=tf.int64),
trainable=False, dtype=tf.int64)
tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
else:
var = tf.train.get_or_create_global_step()
return var
......
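The net effect of the rewritten `get_global_step_var()` is that it can now be called from inside any variable scope, because it re-enters the root scope itself instead of asserting on the caller's scope. A hedged usage sketch; the scope name `tower0` is only an example:

```python
import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var

with tf.variable_scope('tower0'):
    gs = get_global_step_var()   # still named 'global_step', not 'tower0/global_step'

print(gs.name)   # -> 'global_step:0'
# Combined with override_to_local_variable(), the variable also lands in
# tf.GraphKeys.LOCAL_VARIABLES, i.e. one step counter per worker process.
```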
......@@ -35,6 +35,12 @@ class MaintainStepCounter(Callback):
It maintains the global step in the graph, making sure it's increased by one.
This callback is always enabled by the trainer; you don't need to worry about it.
"""
chief_only = False
"""
In distributed training, we let each worker maintain its local global_step.
"""
def _setup_graph(self):
# ensure it exists
gs_var = get_global_step_var()
......
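The part of `_setup_graph` elided above presumably builds the increment op for this (now per-worker) counter. A minimal sketch of that idea, assuming a plain `tf.assign_add`; the op name and the exact wiring into the training loop are illustrative:

```python
import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var

gs_var = get_global_step_var()                       # ensure it exists (local in distributed mode)
gs_incr_op = tf.assign_add(gs_var, 1,
                           name='global_step_incr')  # illustrative name
# The trainer runs this op after every training step, so with
# chief_only = False every worker advances its own local counter.
```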
......@@ -8,7 +8,7 @@ import os
from ..utils import logger
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var
from ..tfutils import get_global_step_var
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
......@@ -75,18 +75,14 @@ class DistributedTrainerReplicated(Trainer):
def _setup(self):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self._builder.task_index))
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join()  # this will never return (tensorflow#4713)
return
# always do this before inputsource.setup because input_source may need the global step
# TODO Can we just do this in get_global_step_var
with tf.device(self._builder.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
with override_to_local_variable():
get_global_step_var() # gs should be local
# input source may create variable (queue size summary)
# TODO This is not good because we don't know from here
# whether something should be global or local. We now assume
......
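For context on the `ps`/`worker` split handled in `_setup`: each process is expected to have been started with a `tf.train.Server` built from a shared cluster spec. A hedged sketch of that surrounding setup; hostnames, ports and task indices are placeholders:

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps':     ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})

# worker process 0
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# a parameter-server process would instead do:
#   server = tf.train.Server(cluster, job_name='ps', task_index=0)
#   server.join()   # never returns (tensorflow#4713), hence the "kill me" log above
```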