Commit 238603b7 authored by Yuxin Wu

Don't let NewSessionCreator finalize graph

parent 4f529aed
...@@ -10,24 +10,25 @@ __all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter'] ...@@ -10,24 +10,25 @@ __all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
""" """
A SessionCreator should: A SessionCreator should:
(optionally) finalize the graph
create the session create the session
initialize all variables initialize all variables
return a session that is ready to use return a session that is ready to use
not finalize the graph
""" """
class NewSessionCreator(tf.train.ChiefSessionCreator): class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', graph=None, config=None): def __init__(self, target='', config=None):
""" """
Args: Args:
target, graph, config: same as :meth:`Session.__init__()`. target, config: same as :meth:`Session.__init__()`.
config: a :class:`tf.ConfigProto` instance, defaults to :func:`tfutils.get_default_sess_config()` config: a :class:`tf.ConfigProto` instance, defaults to :func:`tfutils.get_default_sess_config()`
""" """
assert graph is None self.config = config
self.target = target
if config is None: if config is None:
# distributd trainer doesn't support user-provided config # distributed trainer doesn't support user-provided config
# we set this attribute so that they can check # we set this attribute so that they can check
self.user_provided_config = False self.user_provided_config = False
config = get_default_sess_config() config = get_default_sess_config()
...@@ -37,8 +38,11 @@ class NewSessionCreator(tf.train.ChiefSessionCreator): ...@@ -37,8 +38,11 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
"User-provided custom session config may not work due to TF \ "User-provided custom session config may not work due to TF \
bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.") bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
self.config = config def create_session(self):
super(NewSessionCreator, self).__init__(master=target, config=config) sess = tf.Session(target=self.target, config=self.config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
return sess
class ReuseSessionCreator(tf.train.SessionCreator): class ReuseSessionCreator(tf.train.SessionCreator):
......
...@@ -214,7 +214,7 @@ class Trainer(object): ...@@ -214,7 +214,7 @@ class Trainer(object):
if not isinstance(session_init, JustCurrentSession): if not isinstance(session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!") logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize() # possibly already finalized by ChiefSessionCreator self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
@call_only_once @call_only_once
......
...@@ -404,11 +404,12 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -404,11 +404,12 @@ class HorovodTrainer(SingleCostTrainer):
@HIDE_DOC @HIDE_DOC
def initialize(self, session_creator, session_init): def initialize(self, session_creator, session_init):
# broadcast_op should be the last setup_graph: it needs to be created # broadcast_op should be the last setup_graph: it needs to be created
# "right before" the session is initialized, # "right before" the graph is finalized,
# because it needs to capture all the variables (which may be created by callbacks). # because it needs to capture all the variables (which may be created by callbacks).
with tf.name_scope('horovod_broadcast'): with tf.name_scope('horovod_broadcast'):
self._broadcast_op = hvd.broadcast_global_variables(0) self._broadcast_op = hvd.broadcast_global_variables(0)
# it's important that our NewSessionCreator does not finalize the graph
if not isinstance(session_creator, NewSessionCreator): if not isinstance(session_creator, NewSessionCreator):
raise ValueError( raise ValueError(
"session_creator has to be `NewSessionCreator` for horovod training! ") "session_creator has to be `NewSessionCreator` for horovod training! ")
...@@ -423,6 +424,9 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -423,6 +424,9 @@ class HorovodTrainer(SingleCostTrainer):
# This broadcast belongs to the "initialize" stage # This broadcast belongs to the "initialize" stage
# It should not be delayed to the "before_train" stage. # It should not be delayed to the "before_train" stage.
# TODO:
# 1. an allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set.
logger.info("Broadcasting initialized variables ...") logger.info("Broadcasting initialized variables ...")
self.sess.run(self._broadcast_op) self.sess.run(self._broadcast_op)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment