Commit 0f4eda16 authored by Yuxin Wu's avatar Yuxin Wu

use Monitors as the backend for all the summaries and stats

parent 61f14083
......@@ -146,12 +146,12 @@ def get_config():
RunOp(lambda: M.reset_lstm_state()),
CallbackFactory(
trigger_epoch=lambda self:
[self.trainer.add_scalar_summary(
[self.trainer.monitors.put(
'validation_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN)),
self.trainer.add_scalar_summary(
np.exp(self.trainer.monitors.get_latest('validation_cost') / SEQ_LEN)),
self.trainer.monitors.put(
'test_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('test_cost') / SEQ_LEN))]
np.exp(self.trainer.monitors.get_latest('test_cost') / SEQ_LEN))]
),
],
max_epoch=70,
......
......@@ -152,7 +152,6 @@ def get_config(model, algorithm_name):
MovingAverageSummary(),
ProgressBar(extra_display),
MergeAllSummaries(),
StatPrinter()
],
max_epoch=20,
)
......
......@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
trainer.add_scalar_summary(k, v)
trainer.monitors.put(k, v)
class InferenceRunner(Triggerable):
......
......@@ -318,8 +318,7 @@ class StatMonitorParamSetter(HyperParamSetter):
self.last_changed_epoch = 0
def _get_value_to_set(self):
holder = self.trainer.stat_holder
hist = holder.get_stat_history(self.stat_name)
hist = self.trainer.monitors.get_history(self.stat_name)
if len(hist) < self.last_k + 1 or \
self.epoch_num - self.last_changed_epoch < self.last_k:
return None
......
......@@ -98,7 +98,7 @@ class MinSaver(Triggerable):
def _get_stat(self):
try:
v = self.trainer.stat_holder.get_stat_now(self.monitor_stat)
v = self.trainer.monitors.get_latest(self.monitor_stat)
except KeyError:
v = None
return v
......
......@@ -3,148 +3,22 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import operator
import json
from .base import Triggerable
from ..utils import logger
from ..utils.develop import log_deprecated
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
class StatHolder(object):
"""
A holder to keep all statistics aside from tensorflow events.
"""
def __init__(self, log_dir):
"""
Args:
log_dir(str): directory to save the stats.
"""
self.set_print_tag([])
self.blacklist_tag = set()
self.stat_now = {}
self.log_dir = log_dir
self.filename = os.path.join(log_dir, 'stat.json')
if os.path.isfile(self.filename):
# TODO make a backup first?
logger.info("Found stats at {}, will append to it.".format(self.filename))
with open(self.filename) as f:
self.stat_history = json.load(f)
else:
self.stat_history = []
# global step of the current list of stat
self._current_gs = -1
def add_stat(self, k, v, global_step, epoch_num):
"""
Add a stat.
"""
if global_step != self._current_gs:
self._push()
self._current_gs = global_step
self.stat_now['epoch_num'] = epoch_num
self.stat_now['global_step'] = global_step
self.stat_now[k] = float(v)
def set_print_tag(self, print_tag):
"""
Set name of stats to print.
Args:
print_tag: a collection of string.
"""
self.print_tag = None if print_tag is None else set(print_tag)
def add_blacklist_tag(self, blacklist_tag):
""" Disable printing for some tags
Args:
blacklist_tag: a collection of string.
"""
self.blacklist_tag |= set(blacklist_tag)
def get_stat_now(self, key):
"""
Return the value of a stat in the current epoch.
Raises:
KeyError if the key hasn't been added in this epoch.
"""
return self.stat_now[key]
def get_stat_history(self, key):
"""
Returns:
list: all history of a stat. Empty if there is not history of this name.
"""
ret = []
for h in self.stat_history:
v = h.get(key, None)
if v is not None:
ret.append(v)
v = self.stat_now.get(key, None)
if v is not None:
ret.append(v)
return ret
def finalize(self):
"""
Print and write stats to disk.
This method is idempotent.
"""
self._print_stat()
self._push()
def _push(self):
""" Note that this method is idempotent"""
if len(self.stat_now):
self.stat_history.append(self.stat_now)
self.stat_now = {}
self._write_stat()
def _print_stat(self):
for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
if self.print_tag is None or k in self.print_tag:
if k not in self.blacklist_tag:
logger.info('{}: {:.5g}'.format(k, v))
def _write_stat(self):
tmp_filename = self.filename + '.tmp'
try:
with open(tmp_filename, 'w') as f:
json.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename)
except IOError: # disk error sometimes..
logger.exception("Exception in StatHolder.finalize()!")
__all__ = ['StatPrinter', 'SendStat']
class StatPrinter(Triggerable):
"""
A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
"""
def __init__(self, print_tag=None):
"""
Args:
print_tag: a list of stat names to print.
If None, will print all scalar tags.
"""
self.print_tag = print_tag
def _before_train(self):
self._stat_holder = self.trainer.stat_holder
self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
def _trigger(self):
self._stat_holder.finalize()
log_deprecated("StatPrinter",
"No need to add StatPrinter to callbacks anymore!",
"2017-03-26")
# TODO make it into monitor?
class SendStat(Triggerable):
"""
Execute a command with some specific stats.
......@@ -173,8 +47,8 @@ class SendStat(Triggerable):
self.stats = stats
def _trigger(self):
holder = self.trainer.stat_holder
v = {k: holder.get_stat_now(k) for k in self.stats}
M = self.trainer.monitors
v = {k: M.get_latest(k) for k in self.stats}
cmd = self.command.format(**v)
ret = os.system(cmd)
if ret != 0:
......
......@@ -68,9 +68,9 @@ class MergeAllSummaries(Callback):
summary = run_values.results
if summary is None:
return
self.trainer.add_summary(summary)
self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self):
if self._run_alone:
summary = self.summary_op.eval()
self.trainer.add_summary(summary)
self.trainer.monitors.put_summary(summary)
......@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import ABCMeta, abstractmethod
import re
import time
import weakref
import six
......@@ -15,12 +14,12 @@ from tensorflow.python.training.monitored_session \
from .predict import PredictorFactory
from .config import TrainConfig
from .monitor import Monitors, TrainingMonitor
from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
from ..callbacks import StatHolder, Callback, Callbacks, MaintainStepCounter
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.summary import create_scalar_summary
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
......@@ -40,9 +39,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
monitors (Monitors): the monitors
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch.
......@@ -65,6 +62,8 @@ class Trainer(object):
for cb in self.config.callbacks:
self.register_callback(cb)
self.monitors = config.monitors
def register_callback(self, cb):
"""
Use this method before :meth:`Trainer._setup` finishes,
......@@ -78,6 +77,12 @@ class Trainer(object):
"Cannot register more callbacks after trainer was setup!"
self._callbacks.append(cb)
def register_monitor(self, mon):
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self.monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
self.monitors.append(mon)
def train(self):
""" Start training """
self.setup()
......@@ -88,48 +93,9 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
def trigger_epoch(self):
"""
Called after each epoch.
"""
# trigger subclass
self._trigger_epoch()
# trigger callbacks
self._callbacks.trigger_epoch()
self.summary_writer.flush()
def _trigger_epoch(self):
pass
def add_summary(self, summary):
"""
Add summary to ``self.summary_writer``, and also
add scalar summary to ``self.stat_holder``.
Args:
summary (tf.Summary or str): a summary object, or a str which will
be interpreted as a serialized tf.Summary protobuf.
"""
if isinstance(summary, six.binary_type):
summary = tf.Summary.FromString(summary)
assert isinstance(summary, tf.Summary), type(summary)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150
if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(
val.tag, val.simple_value,
self.global_step, self.epoch_num)
self.summary_writer.add_summary(summary, get_global_step_value())
def add_scalar_summary(self, name, val):
"""
Add a scalar summary to both TF events file and StatHolder.
"""
self.add_summary(create_scalar_summary(name, val))
def setup(self):
"""
Setup the trainer and be ready for the main loop.
......@@ -141,10 +107,9 @@ class Trainer(object):
describe_model()
# some final operations that might modify the graph
logger.info("Setup summaries ...")
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Setup monitors ...")
self.monitors = Monitors(self.monitors)
self.monitors.setup(weakref.proxy(self))
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
......@@ -202,7 +167,10 @@ class Trainer(object):
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time))
self.trigger_epoch() # trigger epoch outside the timing region.
# trigger epoch outside the timing region.
self._trigger_epoch()
self._callbacks.trigger_epoch()
self.monitors.flush()
except StopTraining:
logger.info("Training was stopped.")
except KeyboardInterrupt:
......@@ -211,9 +179,10 @@ class Trainer(object):
raise
finally:
self._callbacks.after_train()
self.summary_writer.close()
self.monitors.close()
self.monitored_sess.close()
# Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0):
"""
Args:
......
......@@ -6,7 +6,7 @@ import tensorflow as tf
from ..callbacks import (
Callbacks, MovingAverageSummary,
StatPrinter, ProgressBar, MergeAllSummaries)
ProgressBar, MergeAllSummaries)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
......@@ -15,6 +15,7 @@ from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.optimizer import apply_grad_processors
from .input_data import InputData
from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
__all__ = ['TrainConfig']
......@@ -24,11 +25,12 @@ class TrainConfig(object):
Config for trainer.
"""
def __init__(self, dataflow=None, data=None,
def __init__(self,
dataflow=None, data=None,
model=None,
callbacks=None, extra_callbacks=None,
session_config=get_default_sess_config(),
session_init=None,
monitors=None,
session_config=get_default_sess_config(), session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0],
**kwargs):
......@@ -41,10 +43,10 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), StatPrinter()]``. The list of
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
Note that ``StatPrinter`` should be the last one to be able to print
stats generated by other callbacks.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
......@@ -86,11 +88,14 @@ class TrainConfig(object):
extra_callbacks = [
MovingAverageSummary(),
ProgressBar(),
MergeAllSummaries(),
StatPrinter()]
MergeAllSummaries()]
self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list)
if monitors is None:
monitors = [TFSummaryWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors
self.model = model
assert_type(self.model, ModelDesc)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: monitor.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import operator
from collections import defaultdict
import six
import json
import re
import tensorflow as tf
from ..utils import logger
__all__ = ['TrainingMonitor', 'Monitors',
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter']
class TrainingMonitor(object):
    """ Base interface for objects that receive training statistics.

    Every hook below defaults to a no-op, so a concrete monitor only
    needs to override the events it cares about.
    """

    def setup(self, trainer):
        # Keep a handle on the trainer so subclasses can read e.g.
        # trainer.global_step / trainer.epoch_num when an event arrives.
        self._trainer = trainer

    def put_summary(self, summary):
        """Receive a tf.Summary protobuf. No-op by default."""
        pass

    def put(self, name, val):  # TODO split by types?
        """Receive a named scalar statistic. No-op by default."""
        pass

    def flush(self):
        """Called once per epoch to persist buffered data. No-op by default."""
        pass

    def close(self):
        """Called when training ends; release resources. No-op by default."""
        pass
class Monitors(TrainingMonitor):
    """ Fan every event out to a list of :class:`TrainingMonitor`.

    A :class:`ScalarHistory` is always appended to the list so that the
    latest / historical value of any scalar can be queried back through
    :meth:`get_latest` and :meth:`get_history`.
    """

    def __init__(self, monitors):
        # TODO filter by names
        self._scalar_history = ScalarHistory()
        self._monitors = monitors + [self._scalar_history]

    def setup(self, trainer):
        for mon in self._monitors:
            mon.setup(trainer)

    def flush(self):
        for mon in self._monitors:
            mon.flush()

    def close(self):
        for mon in self._monitors:
            mon.close()

    def _dispatch_put_summary(self, summary):
        for mon in self._monitors:
            mon.put_summary(summary)

    def _dispatch_put(self, name, val):
        for mon in self._monitors:
            mon.put(name, val)

    def put_summary(self, summary):
        """
        Forward a tf.Summary (or its serialized bytes) to all monitors.
        Simple scalar values inside it are additionally forwarded as
        individual ``put`` events, with tower prefixes stripped.
        """
        if isinstance(summary, six.binary_type):
            summary = tf.Summary.FromString(summary)
        assert isinstance(summary, tf.Summary), type(summary)
        self._dispatch_put_summary(summary)

        # TODO other types
        for val in summary.value:
            if val.WhichOneof('value') != 'simple_value':
                continue
            val.tag = re.sub('tower[p0-9]+/', '', val.tag)  # TODO move to subclasses
            suffix = '-summary'  # issue#6150
            if val.tag.endswith(suffix):
                val.tag = val.tag[:-len(suffix)]
            self._dispatch_put(val.tag, val.simple_value)

    def put(self, name, val):
        """
        Forward a named scalar to all monitors, both as a raw value and
        wrapped in a single-value tf.Summary.
        """
        val = float(val)  # TODO only support numeric for now
        self._dispatch_put(name, val)

        wrapped = tf.Summary()
        wrapped.value.add(tag=name, simple_value=val)
        self._dispatch_put_summary(wrapped)

    def get_latest(self, name):
        """Latest value of a scalar. Raises KeyError if never seen."""
        return self._scalar_history.get_latest(name)

    def get_history(self, name):
        """Every recorded value of a scalar; empty list if never seen."""
        return self._scalar_history.get_history(name)
class TFSummaryWriter(TrainingMonitor):
    """ Write received summaries to TensorFlow event files in
    ``logger.LOG_DIR``, so they show up in TensorBoard. """

    def setup(self, trainer):
        super(TFSummaryWriter, self).setup(trainer)
        # Passing the graph also dumps it into the event file once.
        self._writer = tf.summary.FileWriter(
            logger.LOG_DIR, graph=tf.get_default_graph())

    def put_summary(self, summary):
        # Index each event by the trainer's current global step.
        self._writer.add_summary(summary, self._trainer.global_step)

    def flush(self):
        self._writer.flush()

    def close(self):
        self._writer.close()
class JSONWriter(TrainingMonitor):
def setup(self, trainer):
super(JSONWriter, self).setup(trainer)
self._dir = logger.LOG_DIR
self._fname = os.path.join(self._dir, 'stat.json')
if os.path.isfile(self._fname):
# TODO make a backup first?
logger.info("Found existing JSON at {}, will append to it.".format(self._fname))
with open(self._fname) as f:
self._stats = json.load(f)
assert isinstance(self._stats, list), type(self._stats)
else:
self._stats = []
self._stat_now = {}
self._last_gs = -1
def put(self, name, val):
gs = self._trainer.global_step
if gs != self._last_gs:
self._push()
self._last_gs = gs
self._stat_now['epoch_num'] = self._trainer.epoch_num
self._stat_now['global_step'] = gs
self._stat_now[name] = float(val) # TODO will fail for non-numeric
def _push(self):
""" Note that this method is idempotent"""
if len(self._stat_now):
self._stats.append(self._stat_now)
self._stat_now = {}
self._write_stat()
def _write_stat(self):
tmp_filename = self._fname + '.tmp'
try:
with open(tmp_filename, 'w') as f:
json.dump(self._stats, f)
os.rename(tmp_filename, self._fname)
except IOError: # disk error sometimes..
logger.exception("Exception in StatHolder.finalize()!")
def flush(self):
self._push()
# TODO print interval
class ScalarPrinter(TrainingMonitor):
def __init__(self):
self._whitelist = None
self._blacklist = set([])
def setup(self, _):
self._dic = {}
def put(self, name, val):
self._dic[name] = float(val)
def _print_stat(self):
for k, v in sorted(self._dic.items(), key=operator.itemgetter(0)):
if self._whitelist is None or k in self._whitelist:
if k not in self._blacklist:
logger.info('{}: {:.5g}'.format(k, v))
def flush(self):
self._print_stat()
self._dic = {}
class ScalarHistory(TrainingMonitor):
def setup(self, _):
self._dic = defaultdict(list)
def put(self, name, val):
self._dic[name].append(float(val))
def get_latest(self, name):
hist = self._dic[name]
if len(hist) == 0:
raise KeyError("Invalid key: {}".format(name))
else:
return hist[-1]
def get_history(self, name):
return self._dic[name]
......@@ -237,7 +237,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
if self.config.nr_tower > 1:
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
self.add_scalar_summary(
self.monitors.put(
'async_global_step', async_step_total_cnt)
except:
logger.exception("Cannot log async_global_step")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment