Commit d8627d63 authored by Yuxin Wu's avatar Yuxin Wu

add SessionCreatorAdapter (#191)

parent 16103b17
......@@ -5,7 +5,7 @@
import tensorflow as tf
__all__ = ['NewSessionCreator', 'ReuseSessionCreator']
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
class NewSessionCreator(tf.train.SessionCreator):
......@@ -32,3 +32,19 @@ class ReuseSessionCreator(tf.train.SessionCreator):
def create_session(self):
return self.sess
class SessionCreatorAdapter(tf.train.SessionCreator):
def __init__(self, session_creator, func):
"""
Args:
session_creator (tf.train.SessionCreator): a session creator
func (tf.Session -> tf.Session): takes a session created by
``session_creator``, and return a new session to be returned by ``self.create_session``
"""
self._creator = session_creator
self._func = func
def create_session(self):
sess = self._creator.create_session()
return self._func(sess)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment