Commit 61f14083 authored by Yuxin Wu

Add "register_callback" so that custom trainers can have more control over callbacks and hooks

parent 0a508273
...@@ -9,7 +9,7 @@ from ..models import ModelDesc ...@@ -9,7 +9,7 @@ from ..models import ModelDesc
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSession from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['PredictConfig'] __all__ = ['PredictConfig']
...@@ -28,7 +28,7 @@ class PredictConfig(object): ...@@ -28,7 +28,7 @@ class PredictConfig(object):
Args: Args:
model (ModelDesc): the model to use. model (ModelDesc): the model to use.
session_creator (tf.train.SessionCreator): how to create the session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSession()`. session. Defaults to :class:`sesscreate.NewSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing. Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all input_names (list): a list of input tensor names. Defaults to all
...@@ -52,9 +52,9 @@ class PredictConfig(object): ...@@ -52,9 +52,9 @@ class PredictConfig(object):
if session_creator is None: if session_creator is None:
if session_config is not None: if session_config is not None:
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20") log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSession(config=session_config) self.session_creator = NewSessionCreator(config=session_config)
else: else:
self.session_creator = NewSession(config=get_default_sess_config(0.4)) self.session_creator = NewSessionCreator(config=get_default_sess_config(0.4))
else: else:
self.session_creator = session_creator self.session_creator = session_creator
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
import tensorflow as tf import tensorflow as tf
__all__ = ['NewSession', 'ReuseSession'] __all__ = ['NewSessionCreator', 'ReuseSessionCreator']
class NewSession(tf.train.SessionCreator): class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', graph=None, config=None): def __init__(self, target='', graph=None, config=None):
""" """
Args: Args:
...@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator): ...@@ -22,7 +22,7 @@ class NewSession(tf.train.SessionCreator):
return tf.Session(target=self.target, graph=self.graph, config=self.config) return tf.Session(target=self.target, graph=self.graph, config=self.config)
class ReuseSession(tf.train.SessionCreator): class ReuseSessionCreator(tf.train.SessionCreator):
def __init__(self, sess): def __init__(self, sess):
""" """
Args: Args:
......
...@@ -10,11 +10,14 @@ import six ...@@ -10,11 +10,14 @@ import six
from six.moves import range from six.moves import range
import tensorflow as tf import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from .predict import PredictorFactory from .predict import PredictorFactory
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated, log_deprecated from ..utils.develop import deprecated, log_deprecated
from ..callbacks import StatHolder from ..callbacks import StatHolder, Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..tfutils.summary import create_scalar_summary from ..tfutils.summary import create_scalar_summary
...@@ -57,6 +60,24 @@ class Trainer(object): ...@@ -57,6 +60,24 @@ class Trainer(object):
self.epoch_num = self.config.starting_epoch - 1 self.epoch_num = self.config.starting_epoch - 1
self.local_step = -1 self.local_step = -1
self._callbacks = []
self.register_callback(MaintainStepCounter())
for cb in self.config.callbacks:
self.register_callback(cb)
def register_callback(self, cb):
    """
    Register a callback to the trainer.

    Call this method before :meth:`Trainer._setup` finishes.
    The hooks of the registered callback will be bound to the
    `self.hooked_sess` session.

    Args:
        cb (Callback): the callback to register.
    """
    assert isinstance(cb, Callback), cb
    pending = self._callbacks
    # Once the pending list has been wrapped into a `Callbacks` object,
    # registration is closed and further additions are rejected.
    assert not isinstance(pending, Callbacks), \
        "Cannot register more callbacks after trainer was setup!"
    pending.append(cb)
def train(self): def train(self):
""" Start training """ """ Start training """
self.setup() self.setup()
...@@ -74,7 +95,7 @@ class Trainer(object): ...@@ -74,7 +95,7 @@ class Trainer(object):
# trigger subclass # trigger subclass
self._trigger_epoch() self._trigger_epoch()
# trigger callbacks # trigger callbacks
self.config.callbacks.trigger_epoch() self._callbacks.trigger_epoch()
self.summary_writer.flush() self.summary_writer.flush()
def _trigger_epoch(self): def _trigger_epoch(self):
...@@ -126,7 +147,9 @@ class Trainer(object): ...@@ -126,7 +147,9 @@ class Trainer(object):
self.stat_holder = StatHolder(logger.LOG_DIR) self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
self.config.callbacks.setup_graph(weakref.proxy(self)) self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph() self.config.session_init._setup_graph()
def after_init(scaffold, sess): def after_init(scaffold, sess):
...@@ -140,10 +163,12 @@ class Trainer(object): ...@@ -140,10 +163,12 @@ class Trainer(object):
self.monitored_sess = tf.train.MonitoredSession( self.monitored_sess = tf.train.MonitoredSession(
session_creator=tf.train.ChiefSessionCreator( session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config), scaffold=scaffold, config=self.config.session_config),
hooks=self.config.callbacks.get_hooks()) hooks=None)
self.hooked_sess = self.monitored_sess # just create an alias
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
@abstractmethod @abstractmethod
def _setup(self): def _setup(self):
""" setup Trainer-specific stuff for training""" """ setup Trainer-specific stuff for training"""
...@@ -161,11 +186,10 @@ class Trainer(object): ...@@ -161,11 +186,10 @@ class Trainer(object):
""" """
Run the main training loop. Run the main training loop.
""" """
callbacks = self.config.callbacks
with self.sess.as_default(): with self.sess.as_default():
self._starting_step = get_global_step_value() self._starting_step = get_global_step_value()
try: try:
callbacks.before_train() self._callbacks.before_train()
for self.epoch_num in range( for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num)) logger.info("Start Epoch {} ...".format(self.epoch_num))
...@@ -174,7 +198,7 @@ class Trainer(object): ...@@ -174,7 +198,7 @@ class Trainer(object):
if self.monitored_sess.should_stop(): if self.monitored_sess.should_stop():
return return
self.run_step() # implemented by subclass self.run_step() # implemented by subclass
callbacks.trigger_step() self._callbacks.trigger_step()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time)) self.epoch_num, self.global_step, time.time() - start_time))
...@@ -186,7 +210,7 @@ class Trainer(object): ...@@ -186,7 +210,7 @@ class Trainer(object):
except: except:
raise raise
finally: finally:
callbacks.after_train() self._callbacks.after_train()
self.summary_writer.close() self.summary_writer.close()
self.monitored_sess.close() self.monitored_sess.close()
......
...@@ -6,8 +6,7 @@ import tensorflow as tf ...@@ -6,8 +6,7 @@ import tensorflow as tf
from ..callbacks import ( from ..callbacks import (
Callbacks, MovingAverageSummary, Callbacks, MovingAverageSummary,
StatPrinter, ProgressBar, MergeAllSummaries, StatPrinter, ProgressBar, MergeAllSummaries)
MaintainStepCounter)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -89,9 +88,8 @@ class TrainConfig(object): ...@@ -89,9 +88,8 @@ class TrainConfig(object):
ProgressBar(), ProgressBar(),
MergeAllSummaries(), MergeAllSummaries(),
StatPrinter()] StatPrinter()]
self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list) assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
self.model = model self.model = model
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
......
...@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput): ...@@ -155,7 +155,7 @@ class QueueInput(FeedfreeInput):
def setup_training(self, trainer): def setup_training(self, trainer):
self.setup(trainer.model) self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread)) trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self): def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
...@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -219,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
def setup_training(self, trainer): def setup_training(self, trainer):
self.setup(trainer.model) self.setup(trainer.model)
trainer.config.callbacks.append(StartProcOrThread(self.thread)) trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self): def get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque') ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment