Commit 21c00af1 authored by Yuxin Wu's avatar Yuxin Wu

move variable initializer into NewSessionCreator

parent d6258a71
......@@ -5,9 +5,15 @@
import tensorflow as tf
from .common import get_default_sess_config
from ..utils import logger
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
"""
SessionCreator should return a session that is ready to use
(i.e. variables are initialized)
"""
class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', graph=None, config=None):
......@@ -23,7 +29,10 @@ class NewSessionCreator(tf.train.SessionCreator):
self.graph = graph
def create_session(self):
return tf.Session(target=self.target, graph=self.graph, config=self.config)
sess = tf.Session(target=self.target, graph=self.graph, config=self.config)
sess.run(tf.global_variables_initializer())
logger.info("Global variables initialized.")
return sess
class ReuseSessionCreator(tf.train.SessionCreator):
......
......@@ -118,17 +118,17 @@ class Trainer(object):
self._callbacks.setup_graph(weakref.proxy(self))
# create session
logger.info("Finalize the graph, create the session ...")
logger.info("Creating the session ...")
self.sess = self.config.session_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None)
logger.info("Initializing the session ...")
# init session
init_op = tf.global_variables_initializer()
self.sess.run(init_op)
logger.info("Graph variables initialized.")
self.config.session_init.init(self.sess)
self.sess.graph.finalize()
logger.info("Graph Finalized.")
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment