Commit 61f14083 authored by Yuxin Wu

add "register_callback", so custom trainers can have more control over callbacks&hooks

parent 0a508273
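
For illustration, here is a sketch of how a custom trainer might use the new method. This is not code from the commit: the subclass, the ModelSaver callback, and self.train_op are placeholders, and the import paths are assumed from the repository layout at the time.

# Hypothetical usage sketch, not part of this commit.
from tensorpack.train.base import Trainer      # import path assumed
from tensorpack.callbacks import ModelSaver    # any Callback subclass would do

class MyTrainer(Trainer):
    def _setup(self):
        # Register before _setup() finishes so that the callback's hooks
        # end up in self.hooked_sess.
        self.register_callback(ModelSaver())
        # ... build self.train_op here (placeholder) ...

    def run_step(self):
        self.hooked_sess.run(self.train_op)

Previously such code had to append to trainer.config.callbacks directly, which is exactly what the input-pipeline hunks near the end of this diff stop doing.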
......@@ -9,7 +9,7 @@ from ..models import ModelDesc
from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['PredictConfig']
......@@ -28,7 +28,7 @@ class PredictConfig(object):
Args:
model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`.
session. Defaults to :class:`sesscreate.NewSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
......@@ -52,9 +52,9 @@ class PredictConfig(object):
if session_creator is None:
if session_config is not None:
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSession(config=session_config)
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSession(config=get_default_sess_config(0.4))
self.session_creator = NewSessionCreator(config=get_default_sess_config(0.4))
else:
self.session_creator = session_creator
......
......@@ -5,10 +5,10 @@
import tensorflow as tf
__all__ = ['NewSession', 'ReuseSession']
__all__ = ['NewSessionCreator', 'ReuseSessionCreator']
class NewSession(tf.train.SessionCreator):
class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', graph=None, config=None):
"""
Args:
......@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator):
return tf.Session(target=self.target, graph=self.graph, config=self.config)
class ReuseSession(tf.train.SessionCreator):
class ReuseSessionCreator(tf.train.SessionCreator):
def __init__(self, sess):
"""
Args:
......
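
The rename above (NewSession to NewSessionCreator, ReuseSession to ReuseSessionCreator) surfaces in user code roughly as follows. This is a hedged sketch, not part of the commit: my_model and the tensor names are placeholders.

# Hypothetical usage sketch for the renamed session creators.
from tensorpack import PredictConfig
from tensorpack.tfutils.sesscreate import NewSessionCreator, ReuseSessionCreator

pred_config = PredictConfig(
    model=my_model,                       # some ModelDesc instance (placeholder)
    input_names=['input'],                # placeholder tensor names
    output_names=['output'],
    session_creator=NewSessionCreator(),  # renamed from NewSession
)

# To reuse a session that already exists instead:
# pred_config = PredictConfig(..., session_creator=ReuseSessionCreator(sess))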
......@@ -10,11 +10,14 @@ import six
from six.moves import range
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from .predict import PredictorFactory
from .config import TrainConfig
from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import StatHolder
from ..callbacks import StatHolder, Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.summary import create_scalar_summary
......@@ -57,6 +60,24 @@ class Trainer(object):
self.epoch_num = self.config.starting_epoch - 1
self.local_step = -1
self._callbacks = []
self.register_callback(MaintainStepCounter())
for cb in self.config.callbacks:
self.register_callback(cb)
def register_callback(self, cb):
"""
Use this method before :meth:`Trainer._setup` finishes,
to register a callback with the trainer.
The hooks of the registered callback will be bound to the
`self.hooked_sess` session.
"""
assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \
"Cannot register more callbacks after trainer was setup!"
self._callbacks.append(cb)
def train(self):
""" Start training """
self.setup()
......@@ -74,7 +95,7 @@ class Trainer(object):
# trigger subclass
self._trigger_epoch()
# trigger callbacks
self.config.callbacks.trigger_epoch()
self._callbacks.trigger_epoch()
self.summary_writer.flush()
def _trigger_epoch(self):
......@@ -126,7 +147,9 @@ class Trainer(object):
self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph()
def after_init(scaffold, sess):
......@@ -140,10 +163,12 @@ class Trainer(object):
self.monitored_sess = tf.train.MonitoredSession(
session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks())
self.hooked_sess = self.monitored_sess # just create an alias
hooks=None)
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
@abstractmethod
def _setup(self):
""" setup Trainer-specific stuff for training"""
......@@ -161,11 +186,10 @@ class Trainer(object):
"""
Run the main training loop.
"""
callbacks = self.config.callbacks
with self.sess.as_default():
self._starting_step = get_global_step_value()
try:
callbacks.before_train()
self._callbacks.before_train()
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num))
......@@ -174,7 +198,7 @@ class Trainer(object):
if self.monitored_sess.should_stop():
return
self.run_step() # implemented by subclass
callbacks.trigger_step()
self._callbacks.trigger_step()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time))
......@@ -186,7 +210,7 @@ class Trainer(object):
except:
raise
finally:
callbacks.after_train()
self._callbacks.after_train()
self.summary_writer.close()
self.monitored_sess.close()
......
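
The session wiring introduced above keeps hooks out of the MonitoredSession and attaches them afterwards through _HookedSession on the raw session, so callbacks can keep registering hooks until setup builds the wrapper. A minimal standalone sketch of that pattern (TF 1.x, using the same private APIs as the diff; the toy hook is not from the commit):

# Minimal sketch of the hook-wiring pattern above.
import tensorflow as tf
from tensorflow.python.training.monitored_session \
    import _HookedSession as HookedSession

class PrintStepHook(tf.train.SessionRunHook):        # toy hook for illustration
    def after_run(self, run_context, run_values):
        print("one hooked step finished")

x = tf.constant(1, name='x')                         # build the graph first

# MonitoredSession only handles scaffold/initialization; hooks stay out of it...
monitored = tf.train.MonitoredSession(
    session_creator=tf.train.ChiefSessionCreator(scaffold=tf.train.Scaffold()),
    hooks=None)
raw_sess = monitored._tf_sess()                      # the underlying tf.Session

# ...so the hooks collected from registered callbacks go into this wrapper.
hooked = HookedSession(raw_sess, [PrintStepHook()])
print(hooked.run(x))                                 # runs x with the hook attached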
......@@ -6,8 +6,7 @@ import tensorflow as tf
from ..callbacks import (
Callbacks, MovingAverageSummary,
StatPrinter, ProgressBar, MergeAllSummaries,
MaintainStepCounter)
StatPrinter, ProgressBar, MergeAllSummaries)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
......@@ -89,9 +88,8 @@ class TrainConfig(object):
ProgressBar(),
MergeAllSummaries(),
StatPrinter()]
self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
self.model = model
assert_type(self.model, ModelDesc)
......
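
Since the trainer now registers MaintainStepCounter itself (see the trainer hunk above), a TrainConfig only lists the user's own callbacks. A hedged sketch, with my_dataflow and my_model as placeholders:

# Hypothetical usage sketch; not part of this commit.
from tensorpack import TrainConfig
from tensorpack.callbacks import ModelSaver

config = TrainConfig(
    dataflow=my_dataflow,         # a DataFlow instance (placeholder)
    model=my_model,               # a ModelDesc instance (placeholder)
    callbacks=[ModelSaver()],     # no MaintainStepCounter needed; the trainer adds it
)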
......@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput):
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque')
......@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
def setup_training(self, trainer):
self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
......