Commit c99d013a authored by Yuxin Wu's avatar Yuxin Wu

Allow sesscreate to finalize the graph.

parent 93a177bf
...@@ -68,8 +68,7 @@ class ModelSaver(Callback): ...@@ -68,8 +68,7 @@ class ModelSaver(Callback):
keep_checkpoint_every_n_hours=self._keep_every_n_hours, keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2, write_version=tf.train.SaverDef.V2,
save_relative_paths=True) save_relative_paths=True)
# Don't know how it can be useful, # Scaffold will call saver.build from this collection
# but since there is a predefined key, why not use it?
tf.add_to_collection(tf.GraphKeys.SAVERS, self.saver) tf.add_to_collection(tf.GraphKeys.SAVERS, self.saver)
def _before_train(self): def _before_train(self):
......
...@@ -5,24 +5,27 @@ ...@@ -5,24 +5,27 @@
import tensorflow as tf import tensorflow as tf
from .common import get_default_sess_config from .common import get_default_sess_config
from ..utils import logger
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter'] __all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
""" """
SessionCreator should return a session that is ready to use A SessionCreator should:
(i.e. variables are initialized) (optionally) finalize the graph
create the session
initialize all variables
return a session that is ready to use
""" """
class NewSessionCreator(tf.train.SessionCreator): class NewSessionCreator(tf.train.ChiefSessionCreator):
def __init__(self, target='', graph=None, config=None): def __init__(self, target='', graph=None, config=None):
""" """
Args: Args:
target, graph, config: same as :meth:`Session.__init__()`. target, graph, config: same as :meth:`Session.__init__()`.
config: defaults to :func:`tfutils.get_default_sess_config()` config: defaults to :func:`tfutils.get_default_sess_config()`
""" """
self.target = target assert graph is None
if config is None: if config is None:
# distributd trainer doesn't support user-provided config # distributd trainer doesn't support user-provided config
# we set this attribute so that they can check # we set this attribute so that they can check
...@@ -31,15 +34,7 @@ class NewSessionCreator(tf.train.SessionCreator): ...@@ -31,15 +34,7 @@ class NewSessionCreator(tf.train.SessionCreator):
else: else:
self.user_provided_config = True self.user_provided_config = True
self.config = config super(NewSessionCreator, self).__init__(master=target, config=config)
self.graph = graph
def create_session(self):
sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
logger.info("Global and local variables initialized.")
return sess
class ReuseSessionCreator(tf.train.SessionCreator): class ReuseSessionCreator(tf.train.SessionCreator):
......
...@@ -18,7 +18,7 @@ __all__ = ['SessionInit', 'ChainInit', ...@@ -18,7 +18,7 @@ __all__ = ['SessionInit', 'ChainInit',
class SessionInit(object): class SessionInit(object):
""" Base class for utilities to initialize a (existing) session. """ """ Base class for utilities to load variables to a (existing) session. """
def init(self, sess): def init(self, sess):
""" """
Initialize a session Initialize a session
...@@ -26,9 +26,6 @@ class SessionInit(object): ...@@ -26,9 +26,6 @@ class SessionInit(object):
Args: Args:
sess (tf.Session): the session sess (tf.Session): the session
""" """
self._init(sess)
def _init(self, sess):
self._setup_graph() self._setup_graph()
self._run_init(sess) self._run_init(sess)
...@@ -236,10 +233,6 @@ class ChainInit(SessionInit): ...@@ -236,10 +233,6 @@ class ChainInit(SessionInit):
""" """
self.inits = sess_inits self.inits = sess_inits
def _init(self, sess):
for i in self.inits:
i.init(sess)
def _setup_graph(self): def _setup_graph(self):
for i in self.inits: for i in self.inits:
i._setup_graph() i._setup_graph()
......
...@@ -194,16 +194,17 @@ class Trainer(object): ...@@ -194,16 +194,17 @@ class Trainer(object):
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks) self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self)) self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
logger.info("Creating the session ...") logger.info("Creating the session ...")
self._create_session() self._create_session()
if self.is_chief: if self.is_chief:
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
self._config.session_init.init(self.sess) self._config.session_init._run_init(self.sess)
else: else:
assert isinstance(self._config.session_init, JustCurrentSession), \ if not isinstance(self._config.session_init, JustCurrentSession):
"session_init is only valid for chief worker session!" logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize() self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
......
...@@ -109,6 +109,8 @@ class Trainer(object): ...@@ -109,6 +109,8 @@ class Trainer(object):
Initialize self.sess and self.hooked_sess. Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup. Must be called after callbacks are setup.
""" """
session_init._setup_graph()
logger.info("Creating the session ...") logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks() hooks = self._callbacks.get_hooks()
...@@ -118,20 +120,14 @@ class Trainer(object): ...@@ -118,20 +120,14 @@ class Trainer(object):
if self.is_chief: if self.is_chief:
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
session_init.init(self.sess) session_init._run_init(self.sess)
else: else:
assert isinstance(session_init, JustCurrentSession), \ if not isinstance(self._config.session_init, JustCurrentSession):
"session_init is only valid for chief worker session!" logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize() self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
@call_only_once @call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999): def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
""" """
...@@ -301,8 +297,7 @@ class SingleCostTrainer(TowerTrainer): ...@@ -301,8 +297,7 @@ class SingleCostTrainer(TowerTrainer):
trainer needs are automatically added. trainer needs are automatically added.
""" """
callbacks = callbacks + self._internal_callbacks callbacks = callbacks + self._internal_callbacks
Trainer.train( super(SingleCostTrainer, self).train(
self,
callbacks, monitors, callbacks, monitors,
session_creator, session_init, session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch) steps_per_epoch, starting_epoch, max_epoch)
...@@ -310,7 +305,7 @@ class SingleCostTrainer(TowerTrainer): ...@@ -310,7 +305,7 @@ class SingleCostTrainer(TowerTrainer):
@call_only_once @call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn): def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
""" """
Responsible for building the main training graph. Responsible for building the main training graph for single-cost training.
Args: Args:
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment