Commit f1e3b3ae authored by Yuxin Wu

use sessioncreator API for supervisor

parent c04c1ef8
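The change below routes the Supervisor-managed session construction through TensorFlow's `tf.train.SessionCreator` interface, so the generic trainer only consumes `config.session_creator` instead of building the session by hand. A minimal sketch of that interface, assuming TF 1.x; the creator name here is illustrative, while the commit's own `SupervisedSessionCreator` appears in the diff below:

```python
import tensorflow as tf

class PlainSessionCreator(tf.train.SessionCreator):
    """Illustrative creator; the commit's SupervisedSessionCreator fills this role
    by calling Supervisor.prepare_or_wait_for_session() inside create_session()."""
    def __init__(self, target=''):
        self.target = target

    def create_session(self):
        # Any way of producing a tf.Session is acceptable to callers such as
        # tf.train.MonitoredSession(session_creator=...).
        return tf.Session(target=self.target)

# Trainer-side usage stays uniform regardless of how the session is made:
# sess = tf.train.MonitoredSession(session_creator=PlainSessionCreator(), hooks=[])
```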
@@ -4,19 +4,14 @@
 import tensorflow as tf
 from six.moves import range
-import weakref
-
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession
 
 from ..utils import logger
-from .input_source import StagingInputWrapper, FeedfreeInput
+from .input_source import StagingInputWrapper
 from .feedfree import SingleCostFeedfreeTrainer
 from .multigpu import MultiGPUTrainerBase
-from ..tfutils.model_utils import describe_model
-from ..callbacks import Callbacks, ProgressBar
-from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.common import get_default_sess_config, get_global_step_var, get_op_tensor_name
-from ..callbacks.monitor import Monitors
+from ..callbacks import RunOp
+from ..tfutils.sesscreate import NewSessionCreator
+from ..tfutils.common import get_global_step_var, get_op_tensor_name
 
 __all__ = ['DistributedReplicatedTrainer']
@@ -160,6 +155,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return  # TODO exit and skip mainloop how?
         super(DistributedReplicatedTrainer, self)._setup()
 
+        with tf.device(self.param_server_device):
+            get_global_step_var()
+            self.model.get_optimizer()    # TODO in global scope, not local
+
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
@@ -177,44 +176,36 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         main_fetch = tf.group(*var_update_ops, name='main_fetches')
         self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch])
-        self.post_init_op = self.get_post_init_ops()
-
-    def setup(self):
-        with tf.device(self.param_server_device):
-            gs = get_global_step_var()
-            opt = self.model.get_optimizer()  # in global scope, not local
-        self._setup()
-
-        self.monitors = Monitors(self.monitors)
-        self.register_callback(self.monitors)
-        describe_model()
-        logger.info("Setup callbacks graph ...")
-        self._callbacks = Callbacks(self._callbacks)
-        self._callbacks.setup_graph(weakref.proxy(self))
-
-        logger.info("Finalize the graph, create the session ...")
-
-        self.sv = tf.train.Supervisor(
-            is_chief=self.is_chief,
-            logdir=None,
-            saver=None,
-            global_step=gs,
-            summary_op=None,
-            save_model_secs=0,
-            summary_writer=None)
-        sess = self.sv.prepare_or_wait_for_session(
-            master=self.server.target,
-            start_standard_services=False)
-
-        self.sess = sess
-        logger.info("Running post init op...")
-        sess.run(self.post_init_op)
-        logger.info("Post init op finished.")
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
-        hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+        self.register_callback(RunOp(
+            self.get_post_init_ops, run_before=True, run_as_trigger=False))
+
+        self._set_session_creator()
+
+    def _set_session_creator(self):
+        old_sess_creator = self.config.session_creator
+        if not isinstance(old_sess_creator, NewSessionCreator) \
+                or self.config.session_config is not None:
+            raise ValueError(
+                "Cannot set session_creator or session_config for distributed training! "
+                "To use a custom session config, pass it to the tf.train.Server constructor.")
+
+        class SupervisedSessionCreator(tf.train.SessionCreator):
+            def __init__(self, is_chief, target):
+                self.is_chief = is_chief
+                self.target = target
+
+            def create_session(self):
+                # supervisor will finalize the graph..
+                self.sv = tf.train.Supervisor(
+                    is_chief=self.is_chief,
+                    logdir=None, saver=None,
+                    global_step=get_global_step_var(),
+                    summary_op=None, save_model_secs=0, summary_writer=None)
+                return self.sv.prepare_or_wait_for_session(
+                    master=self.target, start_standard_services=False)
+
+        self.config.session_creator = SupervisedSessionCreator(
+            self.is_chief, self.server.target)
 
     def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
         """Adds ops to enqueue on all worker queues.
@@ -272,5 +263,5 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 copy_to = local_var_by_name[name]
                 post_init_ops.append(copy_to.assign(v.read_value()))
             else:
-                logger.warn("Global var {} doesn't match local var".format(v.name))
+                logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
         return tf.group(*post_init_ops, name='post_init_ops')
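With session creation moved into a creator object, the old explicit `sess.run(self.post_init_op)` also disappears: the variable-copy op returned by `get_post_init_ops()` is now registered as a `RunOp` callback and executed by the normal callback machinery once the session exists. A rough usage sketch, assuming tensorpack's callback API of this era; the no-op below stands in for the real grouped assign ops:

```python
import tensorflow as tf
from tensorpack.callbacks import RunOp

def get_post_init_ops():
    # Stand-in for the tf.group(*post_init_ops) built in get_post_init_ops() above.
    return tf.no_op(name='post_init_ops')

sync_params = RunOp(get_post_init_ops, run_before=True, run_as_trigger=False)
# run_before=True: the op runs once right after the session is created,
# replacing the removed `sess.run(self.post_init_op)` call.
# run_as_trigger=False: it is not re-run at every trigger/epoch.
```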