Commit 28ab4781 authored by Yuxin Wu's avatar Yuxin Wu

avoid using MonitoredSession._tf_sess (#194)

parent 083c1c5f
......@@ -20,6 +20,7 @@ from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
......@@ -118,9 +119,10 @@ class Trainer(object):
# create session
sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...")
self.sess = sess_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession(
session_creator=sess_creator, hooks=None)
self.sess = self._monitored_sess._tf_sess() # expose the underlying session also
session_creator=ReuseSessionCreator(self.sess), hooks=None)
# init session
init_op = tf.global_variables_initializer()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment