Commit ad398637 authored by Yuxin Wu

make Monitor a special kind of Callback, to simplify its implementation and make it more powerful as well
parent b693fe25
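
To make the restructuring concrete, here is a minimal standalone sketch of the new design (simplified stand-ins for illustration, not tensorpack's actual classes): a monitor now inherits the callback lifecycle, so `_setup_graph`, `_trigger`, and `_after_train` replace the old ad-hoc `setup()` / `flush()` / `close()` protocol.

```python
# Simplified stand-ins for illustration only; tensorpack's real classes
# carry more state and error handling.

class Callback(object):
    """Minimal model of the Callback base class."""
    def setup_graph(self, trainer):
        self.trainer = trainer      # monitors can read trainer.global_step etc.
        self._setup_graph()

    def _setup_graph(self):
        pass    # override: build graph-dependent state

    def trigger(self):
        self._trigger()

    def _trigger(self):
        pass    # override: called periodically, e.g. after each epoch

    def after_train(self):
        self._after_train()

    def _after_train(self):
        pass    # override: clean up when training ends


class TrainingMonitor(Callback):
    """A monitor is now just a callback that additionally accepts data."""
    def put_scalar(self, name, val):
        pass    # override: consume a scalar produced during training


class PrintMonitor(TrainingMonitor):
    """Toy monitor mirroring the structure of ScalarPrinter in the diff below."""
    def _setup_graph(self):     # was: setup()
        self._dic = {}

    def put_scalar(self, name, val):
        self._dic[name] = float(val)

    def _trigger(self):         # was: flush()
        for k, v in sorted(self._dic.items()):
            print('{}: {:.5g}'.format(k, v))
        self._dic = {}
```

Because monitors share the callback dispatch, the trainer no longer needs a parallel setup/flush/close path for them, which is the bulk of what this diff deletes.
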
......
@@ -72,7 +72,7 @@ class Callbacks(Callback):
         self.cbs = cbs

     def _setup_graph(self):
-        with tf.name_scope(None):
+        with tf.name_scope(None):   # clear the name scope
             for cb in self.cbs:
                 cb.setup_graph(self.trainer)
......
......
@@ -13,24 +13,25 @@ import re
 import tensorflow as tf

 from ..utils import logger
+from .base import Callback

 __all__ = ['TrainingMonitor', 'Monitors',
            'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter']


-class TrainingMonitor(object):
+class TrainingMonitor(Callback):
     """
     Monitor a training progress, by processing different types of
     summary/statistics from trainer.

     .. document private functions
-    .. automethod:: _setup
+    .. automethod:: _setup_graph
     """
-    def setup(self, trainer):
-        self._trainer = trainer
-        self._setup()
+    def setup_graph(self, trainer):
+        self.trainer = trainer
+        self._setup_graph()

-    def _setup(self):
+    def _setup_graph(self):
         """ Override this method to setup the monitor."""
         pass
......
@@ -51,12 +52,6 @@ class TrainingMonitor(object):
     # TODO put other types

-    def flush(self):
-        pass
-
-    def close(self):
-        pass
-

 class NoOpMonitor(TrainingMonitor):
     pass
......
@@ -67,21 +62,11 @@ class Monitors(TrainingMonitor):
     Merge monitors together for trainer to use.
     """
     def __init__(self, monitors):
         # TODO filter by names
         self._scalar_history = ScalarHistory()
         self._monitors = monitors + [self._scalar_history]

-    def setup(self, trainer):
-        for m in self._monitors:
-            m.setup(trainer)
-
-    def flush(self):
-        for m in self._monitors:
-            m.flush()
-
-    def close(self):
-        for m in self._monitors:
-            m.close()
+    def _setup_graph(self):
+        self._scalar_history.setup_graph(self.trainer)

     def _dispatch_put_summary(self, summary):
         for m in self._monitors:
......
@@ -141,17 +126,16 @@ class TFSummaryWriter(TrainingMonitor):
             logger.warn("logger directory was not set. Ignore TFSummaryWriter.")
             return NoOpMonitor()

-    def setup(self, trainer):
-        super(TFSummaryWriter, self).setup(trainer)
+    def _setup_graph(self):
         self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())

     def put_summary(self, summary):
-        self._writer.add_summary(summary, self._trainer.global_step)
+        self._writer.add_summary(summary, self.trainer.global_step)

-    def flush(self):
+    def _trigger(self):
         self._writer.flush()

-    def close(self):
+    def _after_train(self):
         self._writer.close()
......@@ -166,8 +150,7 @@ class JSONWriter(TrainingMonitor):
logger.warn("logger directory was not set. Ignore JSONWriter.")
return NoOpMonitor()
def setup(self, trainer):
super(JSONWriter, self).setup(trainer)
def _setup_graph(self):
self._dir = logger.LOG_DIR
self._fname = os.path.join(self._dir, 'stat.json')
......
@@ -184,11 +167,11 @@ class JSONWriter(TrainingMonitor):
         self._last_gs = -1

     def put_scalar(self, name, val):
-        gs = self._trainer.global_step
+        gs = self.trainer.global_step
         if gs != self._last_gs:
             self._push()
             self._last_gs = gs
-            self._stat_now['epoch_num'] = self._trainer.epoch_num
+            self._stat_now['epoch_num'] = self.trainer.epoch_num
             self._stat_now['global_step'] = gs
         self._stat_now[name] = float(val)  # TODO will fail for non-numeric
......
@@ -208,7 +191,7 @@ class JSONWriter(TrainingMonitor):
         except IOError:  # disk error sometimes..
             logger.exception("Exception in StatHolder.finalize()!")

-    def flush(self):
+    def _trigger(self):
         self._push()
......
@@ -221,7 +204,7 @@ class ScalarPrinter(TrainingMonitor):
         self._whitelist = None
         self._blacklist = set([])

-    def setup(self, _):
+    def _setup_graph(self):
         self._dic = {}

     def put_scalar(self, name, val):
......
@@ -233,7 +216,7 @@ class ScalarPrinter(TrainingMonitor):
             if k not in self._blacklist:
                 logger.info('{}: {:.5g}'.format(k, v))

-    def flush(self):
+    def _trigger(self):
         self._print_stat()
         self._dic = {}
......
@@ -242,7 +225,7 @@ class ScalarHistory(TrainingMonitor):
     """
     Only used by monitors internally.
     """
-    def setup(self, _):
+    def _setup_graph(self):
         self._dic = defaultdict(list)

     def put_scalar(self, name, val):
......
......
@@ -6,7 +6,7 @@ from pkgutil import iter_modules
 import os
 import os.path

-__all__ = ['monitor']
+__all__ = []


 def global_import(name):
......
@@ -19,7 +19,7 @@ def global_import(name):

 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = ['monitor']
+_SKIP = []
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
......
@@ -14,10 +14,10 @@ from tensorflow.python.training.monitored_session \

 from .predict import PredictorFactory
 from .config import TrainConfig
-from .monitor import Monitors, TrainingMonitor
 from ..utils import logger
 from ..utils.develop import deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
+from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_model
 from ..tfutils.sesscreate import ReuseSessionCreator
......
@@ -40,7 +40,7 @@ class Trainer(object):
         config (TrainConfig): the config used in this trainer.
         model (ModelDesc)
         sess (tf.Session): the current session in use.
-        monitors (Monitors): the monitors
+        monitors (Monitors): the monitors. Callbacks can use it for logging.

         epoch_num (int): the number of epochs that have finished.
         local_step (int): the number of steps that have finished in the current epoch.
......@@ -86,6 +86,7 @@ class Trainer(object):
assert not isinstance(self.monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
self.monitors.append(mon)
self.register_callback(mon)
def train(self):
""" Start training """
......@@ -106,12 +107,12 @@ class Trainer(object):
"""
self._setup() # subclass will setup the graph
describe_model()
# some final operations that might modify the graph
logger.info("Setup monitors ...")
self.monitors = Monitors(self.monitors)
self.monitors.setup(weakref.proxy(self))
self.register_callback(self.monitors)
describe_model()
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
......
@@ -168,7 +169,6 @@ class Trainer(object):
                     # trigger epoch outside the timing region.
                     self._trigger_epoch()
                     self._callbacks.trigger_epoch()
-                    self.monitors.flush()
         except (StopTraining, tf.errors.OutOfRangeError):
             logger.info("Training was stopped.")
         except KeyboardInterrupt:
......@@ -177,7 +177,6 @@ class Trainer(object):
raise
finally:
self._callbacks.after_train()
self.monitors.close()
self._monitored_sess.close()
# Predictor related methods: TODO
......
......
@@ -16,7 +16,7 @@ from ..tfutils import (JustCurrentSession,
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.optimizer import apply_grad_processors
 from .input_data import InputData
-from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
+from ..callbacks.monitor import TFSummaryWriter, JSONWriter, ScalarPrinter

 __all__ = ['TrainConfig']
......
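
After this change, a user-defined monitor overrides callback hooks instead of `flush()` / `close()`. A hypothetical example follows: the `CSVWriter` class and its file handling are purely illustrative; only the hook names, `put_scalar`, and `self.trainer.global_step` come from this diff.

```python
from tensorpack.callbacks.monitor import TrainingMonitor

class CSVWriter(TrainingMonitor):
    """Hypothetical monitor that appends scalars to a CSV file."""
    def _setup_graph(self):             # was: setup(trainer)
        self._file = open('stats.csv', 'a')

    def put_scalar(self, name, val):
        self._file.write('{},{},{}\n'.format(
            self.trainer.global_step, name, float(val)))

    def _trigger(self):                 # was: flush()
        self._file.flush()

    def _after_train(self):             # was: close()
        self._file.close()
```

Because registering a monitor now also registers it as a callback (the `self.monitors.append(mon)` / `self.register_callback(mon)` pair above), these hooks run through the same dispatch as every other callback, with no monitor-specific plumbing in the training loop.
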