Commit 658529d5 authored by Yuxin Wu's avatar Yuxin Wu

add callback <-> hook adapter. fix #147

parent 4c926eb7
...@@ -9,22 +9,12 @@ import traceback ...@@ -9,22 +9,12 @@ import traceback
from .base import Callback from .base import Callback
from .stats import StatPrinter from .stats import StatPrinter
from .hooks import CallbackToHook
from ..utils import logger from ..utils import logger
__all__ = ['Callbacks'] __all__ = ['Callbacks']
class CallbackHook(tf.train.SessionRunHook):
def __init__(self, cb):
self.cb = cb
def before_run(self, ctx):
return self.cb.before_run(ctx)
def after_run(self, ctx, vals):
self.cb.after_run(ctx, vals)
class CallbackTimeLogger(object): class CallbackTimeLogger(object):
def __init__(self): def __init__(self):
self.times = [] self.times = []
...@@ -99,7 +89,7 @@ class Callbacks(Callback): ...@@ -99,7 +89,7 @@ class Callbacks(Callback):
traceback.print_exc() traceback.print_exc()
def get_hooks(self): def get_hooks(self):
return [CallbackHook(cb) for cb in self.cbs] return [CallbackToHook(cb) for cb in self.cbs]
def trigger_step(self): def trigger_step(self):
for cb in self.cbs: for cb in self.cbs:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: hooks.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Compatible layers between tf.train.SessionRunHook and Callback"""
import tensorflow as tf
from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback']
class CallbackToHook(tf.train.SessionRunHook):
""" This is only for internal implementation of
before_run/after_run callbacks.
You shouldn't need to use this.
"""
def __init__(self, cb):
self._cb = cb
def before_run(self, ctx):
return self._cb.before_run(ctx)
def after_run(self, ctx, vals):
self._cb.after_run(ctx, vals)
class HookToCallback(Callback):
"""
Make a ``tf.train.SessionRunHook`` into a callback.
Note that the `coord` argument in `after_create_session` will be None.
"""
def __init__(self, hook):
"""
Args:
hook (tf.train.SessionRunHook):
"""
self._hook = hook
def _setup_graph(self):
with tf.name_scope(None): # jump out of the name scope
self._hook.begin()
def _before_train(self):
sess = tf.get_default_session()
# TODO fix coord?
self._hook.after_create_session(sess, None)
def _before_run(self, ctx):
return self._hook.before_run(ctx)
def _after_run(self, ctx, run_values):
self._hook.after_run(ctx, run_values)
def _after_train(self):
self._hook.end()
...@@ -119,15 +119,14 @@ class Trainer(object): ...@@ -119,15 +119,14 @@ class Trainer(object):
describe_model() describe_model()
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
logger.info("Setup summaries ...") logger.info("Setup summaries ...")
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph()) self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
self.summary_op = tf.summary.merge_all() # XXX not good self.summary_op = tf.summary.merge_all() # XXX not good
# create an empty StatHolder # create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR) self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph() self.config.session_init._setup_graph()
def after_init(_, __): def after_init(_, __):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment