Commit ad398637 authored by Yuxin Wu

make Monitor a special kind of Callback, to simplify its implementation and make it more powerful as well
parent b693fe25
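
The net effect of this commit: a monitor no longer has its own setup/flush/close lifecycle; it inherits the Callback hooks instead. A minimal sketch of a custom monitor after this change (the file-based logging is illustrative, not part of the diff):

    from tensorpack.callbacks.monitor import TrainingMonitor

    class FileMonitor(TrainingMonitor):
        """Illustrative monitor: appends scalars to a text file."""

        def _setup_graph(self):          # was: _setup(self)
            self._file = open('scalars.log', 'w')

        def put_scalar(self, name, val):
            self._file.write('{} {} {}\n'.format(
                self.trainer.global_step, name, val))

        def _trigger(self):              # was: flush(self)
            self._file.flush()

        def _after_train(self):          # was: close(self)
            self._file.close()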
@@ -72,7 +72,7 @@ class Callbacks(Callback):
         self.cbs = cbs

     def _setup_graph(self):
-        with tf.name_scope(None):
+        with tf.name_scope(None):   # clear the name scope
             for cb in self.cbs:
                 cb.setup_graph(self.trainer)
...
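
For context, the comment added above documents a TF1 behavior: passing None to tf.name_scope resets the scope, so ops created by callbacks are not nested under whatever scope the trainer happens to be in. A standalone illustration:

    import tensorflow as tf

    with tf.name_scope('trainer'):
        a = tf.constant(1, name='a')      # created as 'trainer/a'
        with tf.name_scope(None):         # clear the name scope
            b = tf.constant(1, name='b')  # created as 'b', not 'trainer/b'

    print(a.op.name)  # trainer/a
    print(b.op.name)  # b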
@@ -13,24 +13,25 @@ import re
 import tensorflow as tf

 from ..utils import logger
+from .base import Callback

 __all__ = ['TrainingMonitor', 'Monitors',
            'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter']


-class TrainingMonitor(object):
+class TrainingMonitor(Callback):
     """
     Monitor a training progress, by processing different types of
     summary/statistics from trainer.

     .. document private functions
-    .. automethod:: _setup
+    .. automethod:: _setup_graph
     """
-    def setup(self, trainer):
-        self._trainer = trainer
-        self._setup()
+    def setup_graph(self, trainer):
+        self.trainer = trainer
+        self._setup_graph()

-    def _setup(self):
+    def _setup_graph(self):
         """ Override this method to setup the monitor."""
         pass
@@ -51,12 +52,6 @@ class TrainingMonitor(object):

     # TODO put other types

-    def flush(self):
-        pass
-
-    def close(self):
-        pass
-

 class NoOpMonitor(TrainingMonitor):
     pass
@@ -67,21 +62,11 @@ class Monitors(TrainingMonitor):
     Merge monitors together for trainer to use.
     """
     def __init__(self, monitors):
-        # TODO filter by names
         self._scalar_history = ScalarHistory()
         self._monitors = monitors + [self._scalar_history]

-    def setup(self, trainer):
-        for m in self._monitors:
-            m.setup(trainer)
-
-    def flush(self):
-        for m in self._monitors:
-            m.flush()
-
-    def close(self):
-        for m in self._monitors:
-            m.close()
+    def _setup_graph(self):
+        self._scalar_history.setup_graph(self.trainer)

     def _dispatch_put_summary(self, summary):
         for m in self._monitors:
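
Note that the new `_setup_graph` only sets up the private `ScalarHistory`. This is deliberate: as the Trainer diff below shows, every user-supplied monitor is now also registered as a callback, so those monitors receive `setup_graph` from the regular callback machinery; only the `ScalarHistory` created inside `Monitors` itself is never registered anywhere and must be set up by hand. A sketch of the resulting flow, with names taken from this commit:

    # trainer.register_monitor(mon):
    #     self.monitors.append(mon)      # collected, later wrapped in Monitors(...)
    #     self.register_callback(mon)    # mon now gets the callback hooks
    # trainer setup:
    #     self.monitors = Monitors(self.monitors)  # adds the private ScalarHistory
    #     self.register_callback(self.monitors)    # Monitors._setup_graph() runs
    #                                              # and sets up ScalarHistory only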
@@ -141,17 +126,16 @@ class TFSummaryWriter(TrainingMonitor):
             logger.warn("logger directory was not set. Ignore TFSummaryWriter.")
             return NoOpMonitor()

-    def setup(self, trainer):
-        super(TFSummaryWriter, self).setup(trainer)
+    def _setup_graph(self):
         self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())

     def put_summary(self, summary):
-        self._writer.add_summary(summary, self._trainer.global_step)
+        self._writer.add_summary(summary, self.trainer.global_step)

-    def flush(self):
+    def _trigger(self):
         self._writer.flush()

-    def close(self):
+    def _after_train(self):
         self._writer.close()
@@ -166,8 +150,7 @@ class JSONWriter(TrainingMonitor):
             logger.warn("logger directory was not set. Ignore JSONWriter.")
             return NoOpMonitor()

-    def setup(self, trainer):
-        super(JSONWriter, self).setup(trainer)
+    def _setup_graph(self):
         self._dir = logger.LOG_DIR
         self._fname = os.path.join(self._dir, 'stat.json')
@@ -184,11 +167,11 @@ class JSONWriter(TrainingMonitor):
         self._last_gs = -1

     def put_scalar(self, name, val):
-        gs = self._trainer.global_step
+        gs = self.trainer.global_step
         if gs != self._last_gs:
             self._push()
             self._last_gs = gs
-            self._stat_now['epoch_num'] = self._trainer.epoch_num
+            self._stat_now['epoch_num'] = self.trainer.epoch_num
             self._stat_now['global_step'] = gs
         self._stat_now[name] = float(val)    # TODO will fail for non-numeric
@@ -208,7 +191,7 @@ class JSONWriter(TrainingMonitor):
         except IOError:  # disk error sometimes..
             logger.exception("Exception in StatHolder.finalize()!")

-    def flush(self):
+    def _trigger(self):
         self._push()
@@ -221,7 +204,7 @@ class ScalarPrinter(TrainingMonitor):
         self._whitelist = None
         self._blacklist = set([])

-    def setup(self, _):
+    def _setup_graph(self):
         self._dic = {}

     def put_scalar(self, name, val):
@@ -233,7 +216,7 @@ class ScalarPrinter(TrainingMonitor):
             if k not in self._blacklist:
                 logger.info('{}: {:.5g}'.format(k, v))

-    def flush(self):
+    def _trigger(self):
         self._print_stat()
         self._dic = {}
@@ -242,7 +225,7 @@ class ScalarHistory(TrainingMonitor):
     """
     Only used by monitors internally.
     """
-    def setup(self, _):
+    def _setup_graph(self):
         self._dic = defaultdict(list)

     def put_scalar(self, name, val):
...
@@ -6,7 +6,7 @@ from pkgutil import iter_modules
 import os
 import os.path

-__all__ = ['monitor']
+__all__ = []


 def global_import(name):
@@ -19,7 +19,7 @@ def global_import(name):
 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = ['monitor']
+_SKIP = []
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
...
@@ -14,10 +14,10 @@ from tensorflow.python.training.monitored_session \

 from .predict import PredictorFactory
 from .config import TrainConfig
-from .monitor import Monitors, TrainingMonitor
 from ..utils import logger
 from ..utils.develop import deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
+from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
@@ -40,7 +40,7 @@ class Trainer(object):
         config (TrainConfig): the config used in this trainer.
         model (ModelDesc)
         sess (tf.Session): the current session in use.
-        monitors (Monitors): the monitors
+        monitors (Monitors): the monitors. Callbacks can use it for logging.

         epoch_num (int): the number of epochs that have finished.
         local_step (int): the number of steps that have finished in the current epoch.
@@ -86,6 +86,7 @@ class Trainer(object):
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
         self.monitors.append(mon)
+        self.register_callback(mon)

     def train(self):
         """ Start training """
@@ -106,12 +107,12 @@ class Trainer(object):
         """
         self._setup()   # subclass will setup the graph

-        describe_model()
-
-        # some final operations that might modify the graph
-        logger.info("Setup monitors ...")
         self.monitors = Monitors(self.monitors)
-        self.monitors.setup(weakref.proxy(self))
+        self.register_callback(self.monitors)
+
+        describe_model()
+
+        # some final operations that might modify the graph
         logger.info("Setup callbacks graph ...")
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))
@@ -168,7 +169,6 @@ class Trainer(object):
                     # trigger epoch outside the timing region.
                     self._trigger_epoch()
                     self._callbacks.trigger_epoch()
-                    self.monitors.flush()
             except (StopTraining, tf.errors.OutOfRangeError):
                 logger.info("Training was stopped.")
             except KeyboardInterrupt:
@@ -177,7 +177,6 @@ class Trainer(object):
                 raise
             finally:
                 self._callbacks.after_train()
-                self.monitors.close()
                 self._monitored_sess.close()

     # Predictor related methods: TODO
...
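
The two deletions above are the payoff of the refactor: since `Monitors` and every individual monitor are now callbacks, the explicit per-epoch `monitors.flush()` and end-of-training `monitors.close()` are subsumed by the callback group's own lifecycle. A hedged sketch of the equivalence, using only names that appear in this diff:

    # Before: the trainer drove two parallel lifecycles each epoch and at exit:
    #     self._callbacks.trigger_epoch(); self.monitors.flush()
    #     self._callbacks.after_train();   self.monitors.close()
    #
    # After: one lifecycle; monitors override the Callback hooks instead:
    #     _trigger()     -> e.g. TFSummaryWriter flushes, JSONWriter pushes stats
    #     _after_train() -> e.g. TFSummaryWriter closes its FileWriter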
@@ -16,7 +16,7 @@ from ..tfutils import (JustCurrentSession,
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.optimizer import apply_grad_processors
 from .input_data import InputData
-from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
+from ..callbacks.monitor import TFSummaryWriter, JSONWriter, ScalarPrinter

 __all__ = ['TrainConfig']
...
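
For downstream code, the user-visible consequence of the move is just an import-path change. A minimal sketch (assuming a tensorpack checkout at this commit; without a logger directory set, the writers fall back to NoOpMonitor as the diffs above show):

    # Before this commit the writers lived in the train package:
    #   from tensorpack.train.monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
    # After it they live next to the other callbacks:
    from tensorpack.callbacks.monitor import TFSummaryWriter, JSONWriter, ScalarPrinter

    monitors = [TFSummaryWriter(), JSONWriter(), ScalarPrinter()]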