Commit 23c643fb authored by Yuxin Wu

remove supervisor and use sessionmanager

parent a53da5ab
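The commit drops the tf.train.Supervisor-based session bootstrap in the distributed trainer and replaces it with tf.train.SessionManager: the chief worker prepares the session and runs the init op, while every other worker waits until the variables report ready. A minimal standalone sketch of that pattern (TF 1.x API; make_session and its arguments are illustrative, not code from this commit):

import tensorflow as tf

def make_session(server, is_chief):
    # Chief initializes variables; non-chief workers block until they are ready.
    init_op = tf.global_variables_initializer()
    sm = tf.train.SessionManager(
        local_init_op=tf.local_variables_initializer(),
        ready_op=tf.report_uninitialized_variables(),
        graph=tf.get_default_graph())
    if is_chief:
        return sm.prepare_session(master=server.target, init_op=init_op)
    return sm.wait_for_session(master=server.target)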
@@ -136,7 +136,7 @@ def layer_register(
             # log shape info and add activation
             logger.info("{} output: {}".format(
                 scope.name, get_shape_str(outputs)))
-            _LAYER_LOGGED.add(scope.name)
+            _LAYER_LOGGED.add(scope_name)
         else:
             # run the actual function
             outputs = func(*args, **actual_args)
......
@@ -154,7 +154,8 @@ def add_moving_summary(v, *args, **kwargs):
     for x in v:
         assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-    # TODO will produce tower0/xxx?
+    # TODO will produce variable tower0/xxx?
+    # TODO not saved under distributed
     # TODO use zero_debias
     gs = get_global_step_var()
     with tf.name_scope(None), tf.device(gs.device):
......
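The zero_debias TODO refers to bias correction for the moving averages that add_moving_summary maintains: a shadow variable initialized to zero under-estimates the statistic early in training. A hedged sketch of what zero-debiasing could look like with tf.train.ExponentialMovingAverage (the loss tensor below is a placeholder, not tensorpack code):

import tensorflow as tf

loss = tf.reduce_mean(tf.square(tf.random_normal([32])))   # placeholder scalar
# zero_debias=True corrects the early bias of the zero-initialized shadow variable
ema = tf.train.ExponentialMovingAverage(decay=0.95, zero_debias=True)
maintain_op = ema.apply([loss])      # must be run every step to update the average
tf.summary.scalar('loss-ema', ema.average(loss))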
@@ -215,26 +215,26 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 or self.config.session_config is not None:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
-                "To use a custom session config, pass it to the tf.train.Server constructor.")
-        # TODO use scaffold + monitored session
-        class SupervisedSessionCreator(tf.train.SessionCreator):
-            def __init__(self, is_chief, target):
-                self.is_chief = is_chief
-                self.target = target
+                "To use a custom session config, pass it with tf.train.Server.")
+        init_op = tf.global_variables_initializer()
+        local_init_op = tf.local_variables_initializer()
+        ready_op = tf.report_uninitialized_variables()
+        sm = tf.train.SessionManager(
+            local_init_op=local_init_op,
+            ready_op=ready_op, graph=tf.get_default_graph())
+
+        def _create_session():
+            if self.is_chief:
+                return sm.prepare_session(master=self.server.target, init_op=init_op)
+            else:
+                return sm.wait_for_session(master=self.server.target)
+
+        class _Creator(tf.train.SessionCreator):
             def create_session(self):
-                # supervisor will finalize the graph..
-                self.sv = tf.train.Supervisor(
-                    is_chief=self.is_chief,
-                    logdir=None, saver=None,
-                    global_step=get_global_step_var(),
-                    summary_op=None, save_model_secs=0, summary_writer=None)
-                return self.sv.prepare_or_wait_for_session(
-                    master=self.target, start_standard_services=False)
-        self.config.session_creator = SupervisedSessionCreator(
-            self.is_chief, self.server.target)
+                return _create_session()
+
+        self.config.session_creator = _Creator()

     def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
         """Adds ops to enqueue on all worker queues.
......
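For context, tf.train.SessionCreator is the interface tf.train.MonitoredSession uses to obtain its underlying session, which is presumably why the new code wraps the SessionManager calls in a _Creator. A hedged sketch of how such a creator could be consumed downstream (the hook and loop are made-up illustration, not tensorpack's actual training loop):

import tensorflow as tf

class _EchoHook(tf.train.SessionRunHook):
    def after_create_session(self, session, coord):
        print("session ready (chief finished initialization)")

def run_with(creator):
    # `creator` stands in for the _Creator instance assigned to config.session_creator
    with tf.train.MonitoredSession(session_creator=creator,
                                   hooks=[_EchoHook()]) as sess:
        while not sess.should_stop():
            break   # real training steps would run here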