Commit f1e3b3ae authored by Yuxin Wu

use sessioncreator API for supervisor

parent c04c1ef8
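For context: tf.train.SessionCreator is the single-method interface (create_session()) that tf.train.MonitoredSession accepts, and this commit moves the tf.train.Supervisor setup behind it. A minimal standalone sketch of the same pattern, assuming TF 1.x; the class name and the usage lines are illustrative, not part of this commit:

import tensorflow as tf

class SupervisorBackedSessionCreator(tf.train.SessionCreator):
    # Illustrative: expose tf.train.Supervisor through the
    # SessionCreator interface, as the trainer code below does.
    def __init__(self, is_chief, target):
        self.is_chief = is_chief
        self.target = target  # e.g. the target of a tf.train.Server

    def create_session(self):
        # The Supervisor finalizes the graph; non-chief workers wait
        # here until the chief has initialized all variables.
        sv = tf.train.Supervisor(
            is_chief=self.is_chief,
            logdir=None, saver=None,
            summary_op=None, save_model_secs=0, summary_writer=None)
        return sv.prepare_or_wait_for_session(
            master=self.target, start_standard_services=False)

# Usage sketch: any SessionCreator plugs into a MonitoredSession.
# sess = tf.train.MonitoredSession(
#     session_creator=SupervisorBackedSessionCreator(True, server.target))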
@@ -4,19 +4,14 @@
 import tensorflow as tf
 from six.moves import range
-import weakref
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession
 from ..utils import logger
-from .input_source import StagingInputWrapper, FeedfreeInput
+from .input_source import StagingInputWrapper
 from .feedfree import SingleCostFeedfreeTrainer
 from .multigpu import MultiGPUTrainerBase
-from ..tfutils.model_utils import describe_model
-from ..callbacks import Callbacks, ProgressBar
-from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.common import get_default_sess_config, get_global_step_var, get_op_tensor_name
-from ..callbacks.monitor import Monitors
+from ..callbacks import RunOp
+from ..tfutils.sesscreate import NewSessionCreator
+from ..tfutils.common import get_global_step_var, get_op_tensor_name
 
 __all__ = ['DistributedReplicatedTrainer']
@@ -160,6 +155,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return  # TODO exit and skip mainloop how?
         super(DistributedReplicatedTrainer, self)._setup()
+        with tf.device(self.param_server_device):
+            get_global_step_var()
+            self.model.get_optimizer()  # TODO in global scope, not local
+
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
@@ -177,44 +176,36 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         main_fetch = tf.group(*var_update_ops, name='main_fetches')
         self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch])
-        self.post_init_op = self.get_post_init_ops()
-
-    def setup(self):
-        with tf.device(self.param_server_device):
-            gs = get_global_step_var()
-            opt = self.model.get_optimizer()  # in global scope, not local
-        self._setup()
-        self.monitors = Monitors(self.monitors)
-        self.register_callback(self.monitors)
-        describe_model()
-
-        logger.info("Setup callbacks graph ...")
-        self._callbacks = Callbacks(self._callbacks)
-        self._callbacks.setup_graph(weakref.proxy(self))
-
-        logger.info("Finalize the graph, create the session ...")
-        # supervisor will finalize the graph..
-        self.sv = tf.train.Supervisor(
-            is_chief=self.is_chief,
-            logdir=None,
-            saver=None,
-            global_step=gs,
-            summary_op=None,
-            save_model_secs=0,
-            summary_writer=None)
-        sess = self.sv.prepare_or_wait_for_session(
-            master=self.server.target,
-            start_standard_services=False)
-        self.sess = sess
-        logger.info("Running post init op...")
-        sess.run(self.post_init_op)
-        logger.info("Post init op finished.")
-
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
-        hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+        self.register_callback(RunOp(
+            self.get_post_init_ops, run_before=True, run_as_trigger=False))
+        self._set_session_creator()
+
+    def _set_session_creator(self):
+        old_sess_creator = self.config.session_creator
+        if not isinstance(old_sess_creator, NewSessionCreator) \
+                or self.config.session_config is not None:
+            raise ValueError(
+                "Cannot set session_creator or session_config for distributed training! "
+                "To use a custom session config, pass it to the tf.train.Server constructor.")
+
+        class SupervisedSessionCreator(tf.train.SessionCreator):
+            def __init__(self, is_chief, target):
+                self.is_chief = is_chief
+                self.target = target
+
+            def create_session(self):
+                # supervisor will finalize the graph..
+                self.sv = tf.train.Supervisor(
+                    is_chief=self.is_chief,
+                    logdir=None, saver=None,
+                    global_step=get_global_step_var(),
+                    summary_op=None, save_model_secs=0, summary_writer=None)
+                return self.sv.prepare_or_wait_for_session(
+                    master=self.target, start_standard_services=False)
+
+        self.config.session_creator = SupervisedSessionCreator(
+            self.is_chief, self.server.target)
 
     def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
         """Adds ops to enqueue on all worker queues.
@@ -272,5 +263,5 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 copy_to = local_var_by_name[name]
                 post_init_ops.append(copy_to.assign(v.read_value()))
             else:
-                logger.warn("Global var {} doesn't match local var".format(v.name))
+                logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
         return tf.group(*post_init_ops, name='post_init_ops')
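Note the behavioral change in this commit: create_session() no longer runs the post-init op itself. Instead the trainer registers tensorpack's RunOp callback, so get_post_init_ops() is built lazily at callback-setup time and executed once before training. An annotated copy of that registration, with the flag semantics spelled out as I read the diff:

# From the diff above, with assumed semantics in comments:
self.register_callback(RunOp(
    self.get_post_init_ops,   # a callable returning the op, built lazily
    run_before=True,          # run once right before training starts
    run_as_trigger=False))    # do not re-run on every trigger/epoch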