Commit 0f4eda16 authored by Yuxin Wu

use Monitors as the backend for all the summaries and stats

parent 61f14083
@@ -146,12 +146,12 @@ def get_config():
             RunOp(lambda: M.reset_lstm_state()),
             CallbackFactory(
                 trigger_epoch=lambda self:
-                [self.trainer.add_scalar_summary(
+                [self.trainer.monitors.put(
                     'validation_perplexity',
-                    np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN)),
-                 self.trainer.add_scalar_summary(
+                    np.exp(self.trainer.monitors.get_latest('validation_cost') / SEQ_LEN)),
+                 self.trainer.monitors.put(
                     'test_perplexity',
-                    np.exp(self.trainer.stat_holder.get_stat_now('test_cost') / SEQ_LEN))]
+                    np.exp(self.trainer.monitors.get_latest('test_cost') / SEQ_LEN))]
             ),
         ],
         max_epoch=70,
...
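For user code, the migration in this hunk is mechanical: `trainer.add_scalar_summary(name, val)` becomes `trainer.monitors.put(name, val)`, and `trainer.stat_holder.get_stat_now(name)` becomes `trainer.monitors.get_latest(name)`. A minimal sketch of the same callback written as a function rather than a lambda (`SEQ_LEN` is a placeholder here; its real value comes from this example's own config):

```python
import numpy as np
from tensorpack.callbacks import CallbackFactory

SEQ_LEN = 35  # placeholder; defined elsewhere in the original example

def report_perplexity(self):
    # get_latest() reads the most recent scalar recorded under this name
    cost = self.trainer.monitors.get_latest('validation_cost')
    # put() records a new scalar and forwards it to every registered monitor
    self.trainer.monitors.put('validation_perplexity', np.exp(cost / SEQ_LEN))

cb = CallbackFactory(trigger_epoch=report_perplexity)
```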
@@ -152,7 +152,6 @@ def get_config(model, algorithm_name):
             MovingAverageSummary(),
             ProgressBar(extra_display),
             MergeAllSummaries(),
-            StatPrinter()
         ],
         max_epoch=20,
     )
...
@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
         except:
             logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
             continue
-        trainer.add_scalar_summary(k, v)
+        trainer.monitors.put(k, v)

class InferenceRunner(Triggerable):
...
@@ -318,8 +318,7 @@ class StatMonitorParamSetter(HyperParamSetter):
        self.last_changed_epoch = 0

    def _get_value_to_set(self):
-        holder = self.trainer.stat_holder
-        hist = holder.get_stat_history(self.stat_name)
+        hist = self.trainer.monitors.get_history(self.stat_name)
        if len(hist) < self.last_k + 1 or \
                self.epoch_num - self.last_changed_epoch < self.last_k:
            return None
...
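`monitors.get_history(name)` replaces the two-step `stat_holder` lookup and returns every value recorded under a name, oldest first. A standalone sketch of the kind of plateau check `StatMonitorParamSetter` performs with it (the `threshold` knob is assumed for illustration and is not part of the class shown here):

```python
def has_plateaued(monitors, stat_name, last_k, threshold=0.0):
    # full history of the stat, oldest first
    hist = monitors.get_history(stat_name)
    if len(hist) < last_k + 1:
        return False  # same early-out as _get_value_to_set() above
    baseline = hist[-last_k - 1]
    # plateaued if none of the last k values improved on the baseline
    return all(v > baseline - threshold for v in hist[-last_k:])
```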
@@ -98,7 +98,7 @@ class MinSaver(Triggerable):
    def _get_stat(self):
        try:
-            v = self.trainer.stat_holder.get_stat_now(self.monitor_stat)
+            v = self.trainer.monitors.get_latest(self.monitor_stat)
        except KeyError:
            v = None
        return v
...
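`get_latest` raises `KeyError` for a name that has never been recorded (see `ScalarHistory.get_latest` in the new monitor.py below), which is why `MinSaver` keeps its `try/except`. A small helper following the same convention:

```python
def try_get_latest(monitors, name, default=None):
    """Read the latest value of a stat, or `default` if it was never put()."""
    try:
        return monitors.get_latest(name)
    except KeyError:
        return default
```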
@@ -3,148 +3,22 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
-import operator
-import json

from .base import Triggerable
from ..utils import logger
+from ..utils.develop import log_deprecated

-__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
+__all__ = ['StatPrinter', 'SendStat']


-class StatHolder(object):
-    """
-    A holder to keep all statistics aside from tensorflow events.
-    """
-    def __init__(self, log_dir):
-        """
-        Args:
-            log_dir(str): directory to save the stats.
-        """
-        self.set_print_tag([])
-        self.blacklist_tag = set()
-        self.stat_now = {}
-
-        self.log_dir = log_dir
-        self.filename = os.path.join(log_dir, 'stat.json')
-        if os.path.isfile(self.filename):
-            # TODO make a backup first?
-            logger.info("Found stats at {}, will append to it.".format(self.filename))
-            with open(self.filename) as f:
-                self.stat_history = json.load(f)
-        else:
-            self.stat_history = []
-        # global step of the current list of stat
-        self._current_gs = -1
-
-    def add_stat(self, k, v, global_step, epoch_num):
-        """
-        Add a stat.
-        """
-        if global_step != self._current_gs:
-            self._push()
-            self._current_gs = global_step
-            self.stat_now['epoch_num'] = epoch_num
-            self.stat_now['global_step'] = global_step
-        self.stat_now[k] = float(v)
-
-    def set_print_tag(self, print_tag):
-        """
-        Set name of stats to print.
-
-        Args:
-            print_tag: a collection of string.
-        """
-        self.print_tag = None if print_tag is None else set(print_tag)
-
-    def add_blacklist_tag(self, blacklist_tag):
-        """ Disable printing for some tags
-
-        Args:
-            blacklist_tag: a collection of string.
-        """
-        self.blacklist_tag |= set(blacklist_tag)
-
-    def get_stat_now(self, key):
-        """
-        Return the value of a stat in the current epoch.
-
-        Raises:
-            KeyError if the key hasn't been added in this epoch.
-        """
-        return self.stat_now[key]
-
-    def get_stat_history(self, key):
-        """
-        Returns:
-            list: all history of a stat. Empty if there is not history of this name.
-        """
-        ret = []
-        for h in self.stat_history:
-            v = h.get(key, None)
-            if v is not None:
-                ret.append(v)
-        v = self.stat_now.get(key, None)
-        if v is not None:
-            ret.append(v)
-        return ret
-
-    def finalize(self):
-        """
-        Print and write stats to disk.
-        This method is idempotent.
-        """
-        self._print_stat()
-        self._push()
-
-    def _push(self):
-        """ Note that this method is idempotent"""
-        if len(self.stat_now):
-            self.stat_history.append(self.stat_now)
-            self.stat_now = {}
-            self._write_stat()
-
-    def _print_stat(self):
-        for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
-            if self.print_tag is None or k in self.print_tag:
-                if k not in self.blacklist_tag:
-                    logger.info('{}: {:.5g}'.format(k, v))
-
-    def _write_stat(self):
-        tmp_filename = self.filename + '.tmp'
-        try:
-            with open(tmp_filename, 'w') as f:
-                json.dump(self.stat_history, f)
-            os.rename(tmp_filename, self.filename)
-        except IOError:  # disk error sometimes..
-            logger.exception("Exception in StatHolder.finalize()!")


class StatPrinter(Triggerable):
-    """
-    A callback to control what stats to print. Enable by default to print
-    everything in trainer.stat_holder.
-    """
    def __init__(self, print_tag=None):
-        """
-        Args:
-            print_tag: a list of stat names to print.
-                If None, will print all scalar tags.
-        """
-        self.print_tag = print_tag
-
-    def _before_train(self):
-        self._stat_holder = self.trainer.stat_holder
-        self._stat_holder.set_print_tag(self.print_tag)
-        self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
-
-    def _trigger(self):
-        self._stat_holder.finalize()
+        log_deprecated("StatPrinter",
+                       "No need to add StatPrinter to callbacks anymore!",
+                       "2017-03-26")


-# TODO make it into monitor?
class SendStat(Triggerable):
    """
    Execute a command with some specific stats.
@@ -173,8 +47,8 @@ class SendStat(Triggerable):
        self.stats = stats

    def _trigger(self):
-        holder = self.trainer.stat_holder
-        v = {k: holder.get_stat_now(k) for k in self.stats}
+        M = self.trainer.monitors
+        v = {k: M.get_latest(k) for k in self.stats}
        cmd = self.command.format(**v)
        ret = os.system(cmd)
        if ret != 0:
...
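`SendStat` keeps its interface; only the value lookup moves to the monitors backend. A hypothetical use, assuming a scalar named `val_error` was `put()` during the epoch (the command and URL are made up for illustration):

```python
from tensorpack.callbacks import SendStat

# {val_error} is filled from trainer.monitors.get_latest('val_error')
cb = SendStat('curl -s -d error={val_error} http://example.com/notify',
              ['val_error'])
```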
@@ -68,9 +68,9 @@ class MergeAllSummaries(Callback):
        summary = run_values.results
        if summary is None:
            return
-        self.trainer.add_summary(summary)
+        self.trainer.monitors.put_summary(summary)

    def _trigger_epoch(self):
        if self._run_alone:
            summary = self.summary_op.eval()
-            self.trainer.add_summary(summary)
+            self.trainer.monitors.put_summary(summary)
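`put_summary` accepts either a `tf.Summary` object or its serialized bytes: `Monitors.put_summary` in the new monitor.py normalizes the latter with `tf.Summary.FromString`, and also re-dispatches any `simple_value` entries through `put`. A sketch (the `trainer` variable is assumed):

```python
import tensorflow as tf

s = tf.Summary()
s.value.add(tag='train-error', simple_value=0.13)

trainer.monitors.put_summary(s)                      # a tf.Summary object...
trainer.monitors.put_summary(s.SerializeToString())  # ...or its serialized bytes
```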
@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from abc import ABCMeta, abstractmethod
-import re
import time
import weakref
import six
@@ -15,12 +14,12 @@ from tensorflow.python.training.monitored_session \

from .predict import PredictorFactory
from .config import TrainConfig
+from .monitor import Monitors, TrainingMonitor
from ..utils import logger
from ..utils.develop import deprecated, log_deprecated
-from ..callbacks import StatHolder, Callback, Callbacks, MaintainStepCounter
+from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
-from ..tfutils.summary import create_scalar_summary

__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
@@ -40,9 +39,7 @@ class Trainer(object):
        config (TrainConfig): the config used in this trainer.
        model (ModelDesc)
        sess (tf.Session): the current session in use.
-        stat_holder (StatHolder)
-        summary_writer (tf.summary.FileWriter)
+        monitors (Monitors): the monitors

        epoch_num (int): the number of epochs that have finished.
        local_step (int): the number of steps that have finished in the current epoch.
@@ -65,6 +62,8 @@ class Trainer(object):
        for cb in self.config.callbacks:
            self.register_callback(cb)

+        self.monitors = config.monitors
+
    def register_callback(self, cb):
        """
        Use this method before :meth:`Trainer._setup` finishes,
@@ -78,6 +77,12 @@ class Trainer(object):
            "Cannot register more callbacks after trainer was setup!"
        self._callbacks.append(cb)

+    def register_monitor(self, mon):
+        assert isinstance(mon, TrainingMonitor), mon
+        assert not isinstance(self.monitors, Monitors), \
+            "Cannot register more monitors after trainer was setup!"
+        self.monitors.append(mon)
+
    def train(self):
        """ Start training """
        self.setup()
@@ -88,48 +93,9 @@
        """ Abstract method: run one iteration. Subclass should define what is "iteration".
        """

-    def trigger_epoch(self):
-        """
-        Called after each epoch.
-        """
-        # trigger subclass
-        self._trigger_epoch()
-        # trigger callbacks
-        self._callbacks.trigger_epoch()
-        self.summary_writer.flush()
-
    def _trigger_epoch(self):
        pass

-    def add_summary(self, summary):
-        """
-        Add summary to ``self.summary_writer``, and also
-        add scalar summary to ``self.stat_holder``.
-
-        Args:
-            summary (tf.Summary or str): a summary object, or a str which will
-                be interpreted as a serialized tf.Summary protobuf.
-        """
-        if isinstance(summary, six.binary_type):
-            summary = tf.Summary.FromString(summary)
-        assert isinstance(summary, tf.Summary), type(summary)
-        for val in summary.value:
-            if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[p0-9]+/', '', val.tag)  # TODO move to subclasses
-                suffix = '-summary'  # issue#6150
-                if val.tag.endswith(suffix):
-                    val.tag = val.tag[:-len(suffix)]
-                self.stat_holder.add_stat(
-                    val.tag, val.simple_value,
-                    self.global_step, self.epoch_num)
-        self.summary_writer.add_summary(summary, get_global_step_value())
-
-    def add_scalar_summary(self, name, val):
-        """
-        Add a scalar summary to both TF events file and StatHolder.
-        """
-        self.add_summary(create_scalar_summary(name, val))
-
    def setup(self):
        """
        Setup the trainer and be ready for the main loop.
@@ -141,10 +107,9 @@
        describe_model()

        # some final operations that might modify the graph
-        logger.info("Setup summaries ...")
-        self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
-        # create an empty StatHolder
-        self.stat_holder = StatHolder(logger.LOG_DIR)
+        logger.info("Setup monitors ...")
+        self.monitors = Monitors(self.monitors)
+        self.monitors.setup(weakref.proxy(self))

        logger.info("Setup callbacks graph ...")
        self._callbacks = Callbacks(self._callbacks)
@@ -202,7 +167,10 @@
                logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                    self.epoch_num, self.global_step, time.time() - start_time))

-                self.trigger_epoch()  # trigger epoch outside the timing region.
+                # trigger epoch outside the timing region.
+                self._trigger_epoch()
+                self._callbacks.trigger_epoch()
+                self.monitors.flush()
        except StopTraining:
            logger.info("Training was stopped.")
        except KeyboardInterrupt:
@@ -211,9 +179,10 @@
            raise
        finally:
            self._callbacks.after_train()
-            self.summary_writer.close()
+            self.monitors.close()
            self.monitored_sess.close()

+    # Predictor related methods: TODO
    def get_predictor(self, input_names, output_names, tower=0):
        """
        Args:
...
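Like `register_callback`, `register_monitor` only works before `setup()`: once `setup()` wraps the plain list into a `Monitors` dispatcher, the assert above rejects further registrations. A sketch of the expected call order (the monitor instance is hypothetical):

```python
trainer.register_monitor(MyCustomMonitor())  # fine: self.monitors is still a plain list
trainer.train()  # setup() runs; the monitor list is now frozen
```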
@@ -6,7 +6,7 @@ import tensorflow as tf

from ..callbacks import (
    Callbacks, MovingAverageSummary,
-    StatPrinter, ProgressBar, MergeAllSummaries)
+    ProgressBar, MergeAllSummaries)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
@@ -15,6 +15,7 @@ from ..tfutils import (JustCurrentSession,
                       get_default_sess_config, SessionInit)
from ..tfutils.optimizer import apply_grad_processors
from .input_data import InputData
+from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter

__all__ = ['TrainConfig']
@@ -24,11 +25,12 @@ class TrainConfig(object):
    Config for trainer.
    """

-    def __init__(self, dataflow=None, data=None,
+    def __init__(self,
+                 dataflow=None, data=None,
                 model=None,
                 callbacks=None, extra_callbacks=None,
-                 session_config=get_default_sess_config(),
-                 session_init=None,
+                 monitors=None,
+                 session_config=get_default_sess_config(), session_init=None,
                 starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
                 nr_tower=1, tower=None, predict_tower=[0],
                 **kwargs):
@@ -41,10 +43,10 @@
        callbacks (list): a list of :class:`Callback` to perform during training.
        extra_callbacks (list): the same as ``callbacks``. This argument
            is only used to provide the defaults. The defaults are
-            ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), StatPrinter()]``. The list of
+            ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()]``. The list of
            callbacks that will be used in the end are ``callbacks + extra_callbacks``.
-            Note that ``StatPrinter`` should be the last one to be able to print
-            stats generated by other callbacks.
+        monitors (list): a list of :class:`TrainingMonitor`.
+            Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
        session_config (tf.ConfigProto): the config used to instantiate the session.
        session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
        starting_epoch (int): The index of the first epoch.
@@ -86,11 +88,14 @@
            extra_callbacks = [
                MovingAverageSummary(),
                ProgressBar(),
-                MergeAllSummaries(),
-                StatPrinter()]
+                MergeAllSummaries()]
        self.callbacks = callbacks + extra_callbacks
        assert_type(self.callbacks, list)

+        if monitors is None:
+            monitors = [TFSummaryWriter(), JSONWriter(), ScalarPrinter()]
+        self.monitors = monitors
+
        self.model = model
        assert_type(self.model, ModelDesc)
...
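Omitting the new `monitors` argument yields `[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]`; passing a list overrides the defaults entirely. A hypothetical config that keeps TF events and console output but drops the JSON file (import paths inferred from the relative imports in this hunk; `my_dataflow` and `my_model` are assumed):

```python
from tensorpack.train.config import TrainConfig
from tensorpack.train.monitor import TFSummaryWriter, ScalarPrinter

config = TrainConfig(
    dataflow=my_dataflow,    # hypothetical DataFlow
    model=my_model,          # hypothetical ModelDesc
    callbacks=[],
    monitors=[TFSummaryWriter(), ScalarPrinter()],  # no JSONWriter -> no stat.json
)
```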
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: monitor.py
+# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
+
+import os
+import operator
+from collections import defaultdict
+import six
+import json
+import re
+
+import tensorflow as tf
+
+from ..utils import logger
+
+__all__ = ['TrainingMonitor', 'Monitors',
+           'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter']
+
+
+class TrainingMonitor(object):
+    def setup(self, trainer):
+        self._trainer = trainer
+
+    def put_summary(self, summary):
+        pass
+
+    def put(self, name, val):  # TODO split by types?
+        pass
+
+    def flush(self):
+        pass
+
+    def close(self):
+        pass
+
+
+class Monitors(TrainingMonitor):
+    def __init__(self, monitors):
+        # TODO filter by names
+        self._scalar_history = ScalarHistory()
+        self._monitors = monitors + [self._scalar_history]
+
+    def setup(self, trainer):
+        for m in self._monitors:
+            m.setup(trainer)
+
+    def flush(self):
+        for m in self._monitors:
+            m.flush()
+
+    def close(self):
+        for m in self._monitors:
+            m.close()
+
+    def _dispatch_put_summary(self, summary):
+        for m in self._monitors:
+            m.put_summary(summary)
+
+    def _dispatch_put(self, name, val):
+        for m in self._monitors:
+            m.put(name, val)
+
+    def put_summary(self, summary):
+        if isinstance(summary, six.binary_type):
+            summary = tf.Summary.FromString(summary)
+        assert isinstance(summary, tf.Summary), type(summary)
+        self._dispatch_put_summary(summary)
+
+        # TODO other types
+        for val in summary.value:
+            if val.WhichOneof('value') == 'simple_value':
+                val.tag = re.sub('tower[p0-9]+/', '', val.tag)  # TODO move to subclasses
+                suffix = '-summary'  # issue#6150
+                if val.tag.endswith(suffix):
+                    val.tag = val.tag[:-len(suffix)]
+                self._dispatch_put(val.tag, val.simple_value)
+
+    def put(self, name, val):
+        val = float(val)  # TODO only support numeric for now
+        self._dispatch_put(name, val)
+        s = tf.Summary()
+        s.value.add(tag=name, simple_value=val)
+        self._dispatch_put_summary(s)
+
+    def get_latest(self, name):
+        return self._scalar_history.get_latest(name)
+
+    def get_history(self, name):
+        return self._scalar_history.get_history(name)
+
+
+class TFSummaryWriter(TrainingMonitor):
+    def setup(self, trainer):
+        super(TFSummaryWriter, self).setup(trainer)
+        self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
+
+    def put_summary(self, summary):
+        self._writer.add_summary(summary, self._trainer.global_step)
+
+    def flush(self):
+        self._writer.flush()
+
+    def close(self):
+        self._writer.close()
+
+
+class JSONWriter(TrainingMonitor):
+    def setup(self, trainer):
+        super(JSONWriter, self).setup(trainer)
+        self._dir = logger.LOG_DIR
+        self._fname = os.path.join(self._dir, 'stat.json')
+
+        if os.path.isfile(self._fname):
+            # TODO make a backup first?
+            logger.info("Found existing JSON at {}, will append to it.".format(self._fname))
+            with open(self._fname) as f:
+                self._stats = json.load(f)
+                assert isinstance(self._stats, list), type(self._stats)
+        else:
+            self._stats = []
+        self._stat_now = {}
+        self._last_gs = -1
+
+    def put(self, name, val):
+        gs = self._trainer.global_step
+        if gs != self._last_gs:
+            self._push()
+            self._last_gs = gs
+            self._stat_now['epoch_num'] = self._trainer.epoch_num
+            self._stat_now['global_step'] = gs
+        self._stat_now[name] = float(val)  # TODO will fail for non-numeric
+
+    def _push(self):
+        """ Note that this method is idempotent"""
+        if len(self._stat_now):
+            self._stats.append(self._stat_now)
+            self._stat_now = {}
+            self._write_stat()
+
+    def _write_stat(self):
+        tmp_filename = self._fname + '.tmp'
+        try:
+            with open(tmp_filename, 'w') as f:
+                json.dump(self._stats, f)
+            os.rename(tmp_filename, self._fname)
+        except IOError:  # disk error sometimes..
+            logger.exception("Exception in StatHolder.finalize()!")
+
+    def flush(self):
+        self._push()
+
+
+# TODO print interval
+class ScalarPrinter(TrainingMonitor):
+    def __init__(self):
+        self._whitelist = None
+        self._blacklist = set([])
+
+    def setup(self, _):
+        self._dic = {}
+
+    def put(self, name, val):
+        self._dic[name] = float(val)
+
+    def _print_stat(self):
+        for k, v in sorted(self._dic.items(), key=operator.itemgetter(0)):
+            if self._whitelist is None or k in self._whitelist:
+                if k not in self._blacklist:
+                    logger.info('{}: {:.5g}'.format(k, v))
+
+    def flush(self):
+        self._print_stat()
+        self._dic = {}
+
+
+class ScalarHistory(TrainingMonitor):
+    def setup(self, _):
+        self._dic = defaultdict(list)
+
+    def put(self, name, val):
+        self._dic[name].append(float(val))
+
+    def get_latest(self, name):
+        hist = self._dic[name]
+        if len(hist) == 0:
+            raise KeyError("Invalid key: {}".format(name))
+        else:
+            return hist[-1]
+
+    def get_history(self, name):
+        return self._dic[name]
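Because `TrainingMonitor`'s methods are all no-ops, a custom monitor only needs to override what it uses. A hypothetical subclass that appends every scalar to a CSV file, reusing monitor.py's own `os` and `logger` imports:

```python
class CSVWriter(TrainingMonitor):
    """Hypothetical monitor: one 'global_step,name,value' row per scalar."""
    def setup(self, trainer):
        super(CSVWriter, self).setup(trainer)  # stores self._trainer
        self._file = open(os.path.join(logger.LOG_DIR, 'stats.csv'), 'a')

    def put(self, name, val):
        self._file.write('{},{},{}\n'.format(
            self._trainer.global_step, name, val))

    def flush(self):
        self._file.flush()

    def close(self):
        self._file.close()
```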
@@ -237,7 +237,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
            if self.config.nr_tower > 1:
                async_step_total_cnt = int(re.findall(
                    '[0-9]+', self.async_step_counter.__str__())[0])
-                self.add_scalar_summary(
+                self.monitors.put(
                    'async_global_step', async_step_total_cnt)
        except:
            logger.exception("Cannot log async_global_step")
...