Commit c99d013a authored by Yuxin Wu's avatar Yuxin Wu

Allow sesscreate to finalize the graph.

parent 93a177bf
......@@ -68,8 +68,7 @@ class ModelSaver(Callback):
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2,
save_relative_paths=True)
# Don't know how it can be useful,
# but since there is a predefined key, why not use it?
# Scaffold will call saver.build from this collection
tf.add_to_collection(tf.GraphKeys.SAVERS, self.saver)
def _before_train(self):
......
......@@ -5,24 +5,27 @@
import tensorflow as tf
from .common import get_default_sess_config
from ..utils import logger
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
"""
SessionCreator should return a session that is ready to use
(i.e. variables are initialized)
A SessionCreator should:
(optionally) finalize the graph
create the session
initialize all variables
return a session that is ready to use
"""
class NewSessionCreator(tf.train.SessionCreator):
class NewSessionCreator(tf.train.ChiefSessionCreator):
def __init__(self, target='', graph=None, config=None):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
config: defaults to :func:`tfutils.get_default_sess_config()`
"""
self.target = target
assert graph is None
if config is None:
# distributd trainer doesn't support user-provided config
# we set this attribute so that they can check
......@@ -31,15 +34,7 @@ class NewSessionCreator(tf.train.SessionCreator):
else:
self.user_provided_config = True
self.config = config
self.graph = graph
def create_session(self):
sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
logger.info("Global and local variables initialized.")
return sess
super(NewSessionCreator, self).__init__(master=target, config=config)
class ReuseSessionCreator(tf.train.SessionCreator):
......
......@@ -18,7 +18,7 @@ __all__ = ['SessionInit', 'ChainInit',
class SessionInit(object):
""" Base class for utilities to initialize a (existing) session. """
""" Base class for utilities to load variables to a (existing) session. """
def init(self, sess):
"""
Initialize a session
......@@ -26,9 +26,6 @@ class SessionInit(object):
Args:
sess (tf.Session): the session
"""
self._init(sess)
def _init(self, sess):
self._setup_graph()
self._run_init(sess)
......@@ -236,10 +233,6 @@ class ChainInit(SessionInit):
"""
self.inits = sess_inits
def _init(self, sess):
for i in self.inits:
i.init(sess)
def _setup_graph(self):
for i in self.inits:
i._setup_graph()
......
......@@ -194,16 +194,17 @@ class Trainer(object):
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
logger.info("Creating the session ...")
self._create_session()
if self.is_chief:
logger.info("Initializing the session ...")
self._config.session_init.init(self.sess)
self._config.session_init._run_init(self.sess)
else:
assert isinstance(self._config.session_init, JustCurrentSession), \
"session_init is only valid for chief worker session!"
if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize()
logger.info("Graph Finalized.")
......
......@@ -109,6 +109,8 @@ class Trainer(object):
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup.
"""
session_init._setup_graph()
logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks()
......@@ -118,20 +120,14 @@ class Trainer(object):
if self.is_chief:
logger.info("Initializing the session ...")
session_init.init(self.sess)
session_init._run_init(self.sess)
else:
assert isinstance(session_init, JustCurrentSession), \
"session_init is only valid for chief worker session!"
if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
"""
......@@ -301,8 +297,7 @@ class SingleCostTrainer(TowerTrainer):
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
Trainer.train(
self,
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
......@@ -310,7 +305,7 @@ class SingleCostTrainer(TowerTrainer):
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph.
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment