Commit 28ab4781 authored by Yuxin Wu's avatar Yuxin Wu

avoid using MonitoredSession._tf_sess (#194)

parent 083c1c5f
...@@ -20,6 +20,7 @@ from ..utils.develop import deprecated, log_deprecated ...@@ -20,6 +20,7 @@ from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..tfutils.sesscreate import ReuseSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer'] __all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
...@@ -118,9 +119,10 @@ class Trainer(object): ...@@ -118,9 +119,10 @@ class Trainer(object):
# create session # create session
sess_creator = self.config.session_creator sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...") logger.info("Finalize the graph, create the session ...")
self.sess = sess_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession( self._monitored_sess = tf.train.MonitoredSession(
session_creator=sess_creator, hooks=None) session_creator=ReuseSessionCreator(self.sess), hooks=None)
self.sess = self._monitored_sess._tf_sess() # expose the underlying session also
# init session # init session
init_op = tf.global_variables_initializer() init_op = tf.global_variables_initializer()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment