Commit 2eadc59d authored by Yuxin Wu's avatar Yuxin Wu

Add missing SessionCreatorAdapter

parent 55f640f7
......@@ -38,5 +38,5 @@ Keras does not respect variable scopes or variable
collections, which contradicts with tensorpack trainers.
Therefore Keras support is __experimental__.
These simple examples can run within tensorpack smoothly, but note that a future version
of Keras may break them (unlikely, though).
These simple examples can run within tensorpack smoothly, but note that a future
version of Keras or a complicated model may break them (unlikely, though).
......@@ -8,7 +8,7 @@ from ..tfutils.common import tfv1
from ..utils import logger
from .common import get_default_sess_config
__all__ = ['NewSessionCreator', 'ReuseSessionCreator']
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
"""
A SessionCreator should:
......@@ -49,6 +49,9 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
class ReuseSessionCreator(tfv1.train.SessionCreator):
"""
Returns an existing session.
"""
def __init__(self, sess):
"""
Args:
......@@ -58,3 +61,22 @@ class ReuseSessionCreator(tfv1.train.SessionCreator):
def create_session(self):
return self.sess
class SessionCreatorAdapter(tfv1.train.SessionCreator):
"""
Apply a function on the output of a SessionCreator. Can be used to create a debug session.
"""
def __init__(self, session_creator, func):
"""
Args:
session_creator (tf.train.SessionCreator): a session creator
func (tf.Session -> tf.Session): takes a session created by
``session_creator``, and return a new session to be returned by ``self.create_session``
"""
self._creator = session_creator
self._func = func
def create_session(self):
sess = self._creator.create_session()
return self._func(sess)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment