Commit 238603b7 authored by Yuxin Wu

Don't let NewSessionCreator finalize graph

parent 4f529aed
@@ -10,24 +10,25 @@ __all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
"""
A SessionCreator should:
(optionally) finalize the graph
create the session
initialize all variables
return a session that is ready to use
not finalize the graph
"""
class NewSessionCreator(tf.train.ChiefSessionCreator):
def __init__(self, target='', graph=None, config=None):
class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', config=None):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
target, config: same as :meth:`Session.__init__()`.
config: a :class:`tf.ConfigProto` instance, defaults to :func:`tfutils.get_default_sess_config()`
"""
assert graph is None
self.config = config
self.target = target
if config is None:
# distributd trainer doesn't support user-provided config
# distributed trainer doesn't support user-provided config
# we set this attribute so that they can check
self.user_provided_config = False
config = get_default_sess_config()
@@ -37,8 +38,11 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
"User-provided custom session config may not work due to TF \
bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
self.config = config
super(NewSessionCreator, self).__init__(master=target, config=config)
def create_session(self):
sess = tf.Session(target=self.target, config=self.config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
return sess
class ReuseSessionCreator(tf.train.SessionCreator):
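For context, a usage sketch of the revised creator (TF 1.x graph mode; the import path and the toy variable are assumptions for illustration, not part of this commit). Because create_session no longer finalizes the graph, ops can still be added after the session exists, and the caller finalizes when ready:

import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSessionCreator  # assumed import path

v = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
sess = NewSessionCreator().create_session()
print(sess.run(v))             # variables are already initialized by the creator
inc = tf.assign_add(v, 1.0)    # legal only because the graph was not finalized
print(sess.run(inc))
sess.graph.finalize()          # the caller decides when to finalize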
@@ -214,7 +214,7 @@ class Trainer(object):
if not isinstance(session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize() # possibly already finalized by ChiefSessionCreator
self.sess.graph.finalize()
logger.info("Graph Finalized.")
@call_only_once
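The trainer now always finalizes the graph itself after initialization instead of relying on ChiefSessionCreator to have done it already. A minimal sketch of what finalization enforces (plain TF 1.x, independent of tensorpack):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.constant(1)
g.finalize()                   # the graph becomes read-only
try:
    with g.as_default():
        b = tf.constant(2)     # adding any op now raises
except RuntimeError as e:
    print("cannot modify a finalized graph:", e)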
@@ -404,11 +404,12 @@ class HorovodTrainer(SingleCostTrainer):
@HIDE_DOC
def initialize(self, session_creator, session_init):
# broadcast_op should be the last setup_graph: it needs to be created
# "right before" the session is initialized,
# "right before" the graph is finalized,
# because it needs to capture all the variables (which may be created by callbacks).
with tf.name_scope('horovod_broadcast'):
self._broadcast_op = hvd.broadcast_global_variables(0)
# it's important that our NewSessionCreator does not finalize the graph
if not isinstance(session_creator, NewSessionCreator):
raise ValueError(
"session_creator has to be `NewSessionCreator` for horovod training! ")
@@ -423,6 +424,9 @@ class HorovodTrainer(SingleCostTrainer):
# This broadcast belongs to the "initialize" stage
# It should not be delayed to the "before_train" stage.
# TODO:
# 1. an allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set.
logger.info("Broadcasting initialized variables ...")
self.sess.run(self._broadcast_op)
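A condensed sketch of the ordering this hunk depends on (horovod.tensorflow's TF 1.x API; everything around the hvd calls is a stand-in for tensorpack's trainer machinery): the broadcast op is created after all variables exist but before the graph is finalized, and it is run as part of initialization rather than in before_train:

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
# ... build the model; callbacks may still create variables here ...
with tf.name_scope('horovod_broadcast'):
    broadcast_op = hvd.broadcast_global_variables(0)  # created last so it captures all variables

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(broadcast_op)         # part of initialization, not delayed to before_train
sess.graph.finalize()          # only finalize after the broadcast op exists

Launch such a script under horovodrun or mpirun; with a single process the broadcast is effectively a no-op.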