Commit 818e3faf authored by Yuxin Wu's avatar Yuxin Wu

clean-ups in callbacks

parent 0d20cb3d
......@@ -12,6 +12,7 @@ from ..utils import *
__all__ = ['Callbacks']
# --- Test-Callback related stuff seems not very useful.
@contextmanager
def create_test_graph(trainer):
model = trainer.model
......@@ -31,33 +32,6 @@ def create_test_session(trainer):
with tf.Session() as sess:
yield sess
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
def add(self, name, time):
self.tot += time
self.times.append((name, time))
@contextmanager
def timed_callback(self, name):
s = time.time()
yield
self.add(name, time.time() - s)
def log(self):
""" log the time of some heavy callbacks """
if self.tot < 3:
return
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
class TestCallbackContext(object):
"""
A class holding the context needed for running TestCallback
......@@ -91,6 +65,34 @@ class TestCallbackContext(object):
def test_context(self):
with self.graph.as_default(), self.sess.as_default():
yield
# ---
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
def add(self, name, time):
self.tot += time
self.times.append((name, time))
@contextmanager
def timed_callback(self, name):
s = time.time()
yield
self.add(name, time.time() - s)
def log(self):
""" log the time of some heavy callbacks """
if self.tot < 3:
return
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
class Callbacks(Callback):
"""
......
......@@ -13,7 +13,7 @@ from ..utils import *
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from .base import Callback, TestCallbackType
from .base import Callback
__all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
......@@ -63,7 +63,6 @@ class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
"""
#type = TestCallbackType()
def __init__(self, ds, vcs):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment