Commit b5f8c73a authored by Yuxin Wu's avatar Yuxin Wu

sphinx doc for callbacks

parent b5acbf3a
......@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The binary-weight 4-bit-activation ResNet-18 model has 59.2% top-1 validation error.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation error.
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
......
......@@ -27,7 +27,8 @@ class Callback(object):
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
:param trainer: :class:`train.Trainer` instance
Args:
trainer(Trainer): the trainer which calls the callback
"""
self.trainer = trainer
self.graph = tf.get_default_graph()
......@@ -59,7 +60,7 @@ class Callback(object):
"""
Triggered after every epoch.
In this function, self.epoch_num would be the number of epoch finished.
In this function, ``self.epoch_num`` would be the number of epoch finished.
"""
self.epoch_num += 1
self._trigger_epoch()
......@@ -72,8 +73,15 @@ class Callback(object):
class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
"""
def __init__(self, cb):
"""
Args:
cb(Callback): the underlying callback
"""
self.cb = cb
def _before_train(self):
......@@ -94,14 +102,20 @@ class ProxyCallback(Callback):
class PeriodicCallback(ProxyCallback):
"""
A callback to be triggered after every `period` epochs.
Doesn't work for trigger_step
Wrap a callback so that it is triggered after every ``period`` epochs.
Doesn't work for ``trigger_step``.
"""
def __init__(self, cb, period):
"""
:param cb: a `Callback`
:param period: int
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super(PeriodicCallback, self).__init__(cb)
self.period = int(period)
......
......@@ -11,15 +11,19 @@ __all__ = ['StartProcOrThread']
class StartProcOrThread(Callback):
"""
Start some threads or processes before training.
"""
def __init__(self, procs_threads):
def __init__(self, startable):
"""
Start extra threads and processes before training
:param procs_threads: list of processes or threads
Args:
startable(list): list of processes or threads which have ``start()`` method.
Can also be a single instance of process of thread.
"""
if not isinstance(procs_threads, list):
procs_threads = [procs_threads]
self._procs_threads = procs_threads
if not isinstance(startable, list):
startable = [startable]
self._procs_threads = startable
def _before_train(self):
logger.info("Starting " +
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dispatcher.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']
class OutputTensorDispatcer(object):
def __init__(self):
self._names = []
self._idxs = []
def add_entry(self, names):
v = []
for n in names:
tensorname = get_op_tensor_name(n)[1]
if tensorname in self._names:
v.append(self._names.index(tensorname))
else:
self._names.append(tensorname)
v.append(len(self._names) - 1)
self._idxs.append(v)
def get_all_names(self):
return self._names
def get_idx_for_each_entry(self):
return self._idxs
......@@ -8,26 +8,27 @@ import numpy as np
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name
from ..tfutils import get_op_tensor_name
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
"""
Dump a variable to image(s) after every epoch to logger.LOG_DIR.
Dump a variable to image(s) to ``logger.LOG_DIR`` after every epoch.
"""
def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
"""
:param var_name: the name of the variable.
:param prefix: the filename prefix for saved images. Default is the op name.
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param clip: whether to clip the result to [0, 255]
Args:
var_name (str): the name of the variable.
prefix (str): the filename prefix for saved images. Defaults to the Op name.
map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity.
scale (float): a multiplier on pixel values, applied after map_func.
clip (bool): whether to clip the result to [0, 255].
"""
op_name, self.var_name = get_op_var_name(var_name)
op_name, self.var_name = get_op_tensor_name(var_name)
self.func = map_func
if prefix is None:
self.prefix = op_name
......@@ -45,7 +46,7 @@ class DumpParamAsImage(Callback):
val = self.trainer.sess.run(self.var)
if self.func is not None:
val = self.func(val)
if isinstance(val, list):
if isinstance(val, list) or val.ndim == 4:
for idx, im in enumerate(val):
self._dump_image(im, idx)
else:
......
......@@ -11,13 +11,19 @@ __all__ = ['RunOp']
class RunOp(Callback):
""" Run an op periodically"""
""" Run an Op. """
def __init__(self, setup_func, run_before=True, run_epoch=True):
"""
:param setup_func: a function that returns the op in the graph
:param run_before: run the op before training
:param run_epoch: run the op on every epoch trigger
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_epoch (bool): run the Op on every epoch trigger
Examples:
The `DQN Example
<https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
uses this callback to update target network.
"""
self.setup_func = setup_func
self.run_before = run_before
......@@ -25,7 +31,6 @@ class RunOp(Callback):
def _setup_graph(self):
self._op = self.setup_func()
# self._op_name = self._op.name
def _before_train(self):
if self.run_before:
......
......@@ -44,12 +44,14 @@ class CallbackTimeLogger(object):
class Callbacks(Callback):
"""
A container to hold all callbacks, and execute them in the right order and proper session.
A container to hold all callbacks, and execute them in the right order
(e.g. :class:`StatPrinter` will be executed at last).
"""
def __init__(self, cbs):
"""
:param cbs: a list of `Callbacks`
Args:
cbs(list): a list of :class:`Callback` instances.
"""
# check type
for cb in cbs:
......
......@@ -12,12 +12,13 @@ from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name
__all__ = ['ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
__all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta)
class Inferencer(object):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """
def before_inference(self):
"""
......@@ -30,7 +31,11 @@ class Inferencer(object):
def datapoint(self, output):
"""
Called after complete running every data point
Called after each new datapoint finished the forward inference.
Args:
output(list): list of output this inferencer needs. Has the same
length as ``self.get_output_tensors()``.
"""
self._datapoint(output)
......@@ -41,8 +46,8 @@ class Inferencer(object):
def after_inference(self):
"""
Called after a round of inference ends.
Returns a dict of statistics which will be logged by the InferenceRunner.
The inferencer needs to handle other kind of logging by their own.
Returns a dict of statistics which will be logged by the :class:`InferenceRunner`.
The inferencer needs to handle other type of logging by itself, if there is any.
"""
return self._after_inference()
......@@ -51,7 +56,7 @@ class Inferencer(object):
def get_output_tensors(self):
"""
Return a list of tensor names needed for this inference
Return a list of tensor names this inferencer needed.
"""
return self._get_output_tensors()
......@@ -62,15 +67,16 @@ class Inferencer(object):
class ScalarStats(Inferencer):
"""
Write some scalar tensor to both stat and summary.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the inference dataflow.
Statistics of some scalar tensor.
The value will be averaged over all given datapoints.
"""
def __init__(self, names_to_print, prefix='validation'):
"""
:param names_to_print: list of names of tensors, or just a name
:param prefix: an optional prefix for logging
Args:
names_to_print(list or str): list of names or just one name. The
corresponding tensors have to be scalar.
prefix(str): a prefix for logging
"""
if not isinstance(names_to_print, list):
self.names = [names_to_print]
......@@ -85,6 +91,8 @@ class ScalarStats(Inferencer):
self.stats = []
def _datapoint(self, output):
for o in output:
assert isinstance(o, (float, np.float32)), type(o)
self.stats.append(output)
def _after_inference(self):
......@@ -101,24 +109,27 @@ class ScalarStats(Inferencer):
class ClassificationError(Inferencer):
"""
Compute classification error in batch mode, from a `wrong` variable
Compute classification error in batch mode, from a ``wrong`` tensor.
The `wrong` tensor is supposed to be an 0/1 integer vector containing
whether each sample in the batch is incorrectly classified.
You can use `tf.nn.in_top_k` to produce this vector record top-k error as well.
The ``wrong`` tensor is supposed to be an binary vector containing
whether each sample in the batch is *incorrectly* classified.
You can use ``tf.nn.in_top_k`` to produce this vector.
This callback produce the "true" error,
This Inferencer produces the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
Therefore the result is different from averaging the error rate of each batch.
Therefore the result can be different from averaging the error rate of each batch.
"""
def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
def __init__(self, wrong_tensor_name='incorrect_vector', summary_name='val_error'):
"""
:param wrong_var_name: name of the `wrong` variable
:param summary_name: the name for logging
Args:
wrong_tensor_name(str): name of the ``wrong`` tensor.
The default is the same as the default output name of
:meth:`prediction_incorrect`.
summary_name(str): the name for logging.
"""
self.wrong_var_name = wrong_var_name
self.wrong_var_name = wrong_tensor_name
self.summary_name = summary_name
def _get_output_tensors(self):
......@@ -144,21 +155,23 @@ class ClassificationError(Inferencer):
class BinaryClassificationStats(Inferencer):
""" Compute precision/recall in binary classification, given the
"""
Compute precision / recall in binary classification, given the
prediction vector and the label vector.
"""
def __init__(self, pred_var_name, label_var_name, summary_prefix='val'):
def __init__(self, pred_tensor_name, label_tensor_name, summary_prefix='val'):
"""
:param pred_var_name: name of the 0/1 prediction tensor.
:param label_var_name: name of the 0/1 label tensor.
Args:
pred_tensor_name(str): name of the 0/1 prediction tensor.
label_tensor_name(str): name of the 0/1 label tensor.
"""
self.pred_var_name = pred_var_name
self.label_var_name = label_var_name
self.pred_tensor_name = pred_tensor_name
self.label_tensor_name = label_tensor_name
self.prefix = summary_prefix
def _get_output_tensors(self):
return [self.pred_var_name, self.label_var_name]
return [self.pred_tensor_name, self.label_tensor_name]
def _before_inference(self):
self.stat = BinaryStatistics()
......
......@@ -11,13 +11,36 @@ from six.moves import zip, range
from ..dataflow import DataFlow
from .base import Callback
from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer
from ..utils import logger, get_tqdm
from ..tfutils.common import get_op_tensor_name
from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']
class OutputTensorDispatcer(object):
def __init__(self):
self._names = []
self._idxs = []
def add_entry(self, names):
v = []
for n in names:
tensorname = get_op_tensor_name(n)[1]
if tensorname in self._names:
v.append(self._names.index(tensorname))
else:
self._names.append(tensorname)
v.append(len(self._names) - 1)
self._idxs.append(v)
def get_all_names(self):
return self._names
def get_idx_for_each_entry(self):
return self._idxs
def summary_inferencer(trainer, infs):
for inf in infs:
ret = inf.after_inference()
......@@ -32,17 +55,19 @@ def summary_inferencer(trainer, infs):
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
"""
IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param infs: a list of `Inferencer` instance.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_tensor_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
......@@ -78,7 +103,7 @@ class InferenceRunner(Callback):
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
IOTensor = InferenceRunner.IOTensor
IOTensor = InferenceRunner._IOTensor
self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names))
......
......@@ -13,40 +13,59 @@ from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter',
'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
'HyperParam', 'GraphVarParam', 'ObjAttrParam']
]
@six.add_metaclass(ABCMeta)
class HyperParam(object):
""" Base class for a hyper param"""
""" Base class for a hyperparam. """
def setup_graph(self):
""" setup the graph in `setup_graph` callback stage, if necessary"""
""" setup the graph in ``setup_graph`` callback stage, if necessary"""
pass
@abstractmethod
def set_value(self, v):
""" define how the value of the param will be set"""
"""
Set the value of the param.
Args:
v: the value to be set
"""
pass
@abstractmethod
def get_value(self):
"""
Get the value of the param.
"""
pass
@property
def readable_name(self):
""" A name to display"""
""" A name to display """
return self._readable_name
class GraphVarParam(HyperParam):
""" a variable in the graph can be a hyperparam"""
""" A variable in the graph (e.g. learning_rate) can be a hyperparam"""
def __init__(self, name, shape=[]):
"""
Args:
name(str): name of the variable.
shape(list): shape of the variable.
"""
self.name = name
self.shape = shape
self._readable_name, self.var_name = get_op_var_name(name)
def setup_graph(self):
""" Will setup the assign operator for that variable. """
all_vars = tf.global_variables()
for v in all_vars:
if v.name == self.var_name:
......@@ -60,17 +79,24 @@ class GraphVarParam(HyperParam):
self.assign_op = self.var.assign(self.val_holder)
def set_value(self, v):
""" Assign the variable a new value. """
self.assign_op.eval(feed_dict={self.val_holder: v})
def get_value(self):
""" Evaluate the variable. """
return self.var.eval()
class ObjAttrParam(HyperParam):
""" an attribute of an object can be a hyperparam"""
""" An attribute of an object can be a hyperparam. """
def __init__(self, obj, attrname, readable_name=None):
""" :param readable_name: default to be attrname."""
"""
Args:
obj: the object
attrname (str): the attribute
readable_name(str): The name to display. Defaults to be ``attrname``.
"""
self.obj = obj
self.attrname = attrname
if readable_name is None:
......@@ -87,12 +113,14 @@ class ObjAttrParam(HyperParam):
class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
An abstract base callback to set hyperparameters in every epoch.
"""
def __init__(self, param):
"""
:param param: a `HyperParam` instance, or a string (assumed to be a scalar `GraphVarParam`)
Args:
param(HyperParam or str): if is a :class:`str`, it is assumed to
be a :class:`GraphVarParam`.
"""
# if a string, assumed to be a scalar graph variable
if isinstance(param, six.string_types):
......@@ -106,7 +134,13 @@ class HyperParamSetter(Callback):
def get_value_to_set(self):
"""
:returns: the value to assign to the variable now.
Returns:
The value to assign to the variable.
Note:
Subclasses will implemenet the abstract method
:meth:`_get_value_to_set`, which should return a new value to
set, or return None to do nothing.
"""
ret = self._get_value_to_set()
if ret is not None and ret != self.last_value:
......@@ -115,13 +149,17 @@ class HyperParamSetter(Callback):
self.last_value = ret
return ret
def get_current_value(self):
return self.param.get_value()
@abstractmethod
def _get_value_to_set(self):
pass
def get_current_value(self):
"""
Returns:
The current value of the param.
"""
return self.param.get_value()
def _trigger_epoch(self):
self._set_param()
......@@ -136,14 +174,19 @@ class HyperParamSetter(Callback):
class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by loading the value from a file each time it get called.
Set hyperparameter by loading the value from a file each time it get called.
This is useful for manually tuning some parameters (e.g. learning_rate)
without interrupting the training.
"""
def __init__(self, param, file_name='hyper.txt'):
"""
:param file_name: a file containing the value of the variable.
Args:
param: same as in :class:`HyperParamSetter`.
file_name(str): a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
param.readable_name, and v is the value
param.readable_name, and v is the value. If the pair is not found,
the param will not be changed.
"""
super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name)
......@@ -170,15 +213,25 @@ class HumanHyperParamSetter(HyperParamSetter):
class ScheduledHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by a predefined schedule.
Set hyperparameters by a predefined epoch-based schedule.
"""
def __init__(self, param, schedule, interp=None):
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
(ep, val) means set the param to "val" after the `ep`th epoch.
If epoch == 0, the value is set before training.
:param interp: None: no interpolation. 'linear': linear interpolation
Args:
param: same as in :class:`HyperParamSetter`.
schedule(list): with the format ``[(epoch1, val1), (epoch2, val2),
(epoch3, val3), ...]``.
Each ``(ep, val)`` pair means to set the param
to "val" after the `ep`th epoch.
If ep == 0, the value will be set before training.
interp: None: no interpolation. 'linear': linear interpolation
Example:
.. code-block:: python
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
"""
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0))
......@@ -209,10 +262,20 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class HyperParamSetterWithFunc(HyperParamSetter):
""" Set the parameter by a function of epoch num and old value. """
def __init__(self, param, func):
"""Set hyperparameter by a func
new_value = f(epoch_num, old_value)
"""
Args:
param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
Example:
Decrease by a factor of 0.9 every two epochs:
.. code-block:: python
HyperParamSetterWithFunc('learning_rate',
lambda e, x: x * 0.9 if e % 2 == 0 else x)
"""
super(HyperParamSetterWithFunc, self).__init__(param)
self.f = func
......@@ -222,22 +285,32 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class StatMonitorParamSetter(HyperParamSetter):
"""
Change the param by monitoring the change of a statistic.
Change when it wasn't decreasing/increasing enough.
"""
def __init__(self, param, stat_name, value_func, threshold,
last_k, reverse=False
):
last_k, reverse=False):
"""
Set hyperparameter by a func, when a specific stat wasn't
decreasing/increasing enough in the last $k$ epochs.
Change param by `new_value = value_func(old_value)`,
if :
min(stats) >= stats[0] - threshold, where
stats = [`stat_nam` in latest `last_k` epochs]
Args:
param: same as in :class:`HyperParamSetter`.
stat_name (str): name of the statistics.
value_func (float -> float): a function which returns a new value
taking the old value.
threshold (float): change threshold.
last_k (int): last k epochs.
reverse (bool): monitor increasing instead of decreasing.
This callback will change param by ``new_value = value_func(old_value)``, when:
``min(stats) >= stats[0] - threshold``, where
``stats = [stat_name in last k epochs]``
Example:
If validation error wasn't decreasing for 5 epochs, anneal the learning rate:
For example, if error wasn't decreasing, anneal the learning rate:
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
.. code-block:: python
If reverse==True, use 'increasing' instead of decreasing
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
"""
super(StatMonitorParamSetter, self).__init__(param)
self.stat_name = stat_name
......
......@@ -16,22 +16,19 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
"""
Save the model to logger directory.
Save the model to ``logger.LOG_DIR`` directory every epoch.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
var_collections=None):
var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
Args:
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(int): see ``tf.train.Saver`` documentation.
var_collections (str or list): the variable collection (or list of collections) o save.
"""
self.keep_recent = keep_recent
self.keep_freq = keep_freq
if var_collections is None:
try:
var_collections = tf.GraphKeys.GLOBAL_VARIABLES
except:
var_collections = tf.GraphKeys.VARIABLES
if not isinstance(var_collections, list):
var_collections = [var_collections]
self.var_collections = var_collections
......@@ -87,8 +84,25 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
class MinSaver(Callback):
"""
Separately save the model with minimum value of some statistics.
"""
def __init__(self, monitor_stat, reverse=False, filename=None):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel" under ``logger.LOG_DIR``:
.. code-block:: python
def __init__(self, monitor_stat, reverse=True, filename=None):
MinSaver('val-error')
"""
self.monitor_stat = monitor_stat
self.reverse = reverse
self.filename = filename
......@@ -128,6 +142,14 @@ class MinSaver(Callback):
class MaxSaver(MinSaver):
def __init__(self, monitor_stat):
super(MaxSaver, self).__init__(monitor_stat, True)
"""
Separately save the model with maximum value of some statistics.
"""
def __init__(self, monitor_stat, filename=None):
"""
Args:
monitor_stat(str): the name of the statistics.
filename (str): the name for the saved model.
Defaults to ``max-{monitor_stat}.tfmodel``.
"""
super(MaxSaver, self).__init__(monitor_stat, True, filename=filename)
......@@ -20,7 +20,8 @@ class StatHolder(object):
def __init__(self, log_dir):
"""
:param log_dir: directory to save the stats.
Args:
log_dir(str): directory to save the stats.
"""
self.set_print_tag([])
self.blacklist_tag = set()
......@@ -38,19 +39,24 @@ class StatHolder(object):
def add_stat(self, k, v):
"""
Add a stat.
:param k: name
:param v: value
"""
self.stat_now[k] = float(v)
def set_print_tag(self, print_tag):
"""
Set name of stats to print.
Args:
print_tag: a collection of string.
"""
self.print_tag = None if print_tag is None else set(print_tag)
def add_blacklist_tag(self, blacklist_tag):
""" Disable printing for some tags """
""" Disable printing for some tags
Args:
blacklist_tag: a collection of string.
"""
self.blacklist_tag |= set(blacklist_tag)
def get_stat_now(self, key):
......@@ -60,6 +66,10 @@ class StatHolder(object):
return self.stat_now[key]
def get_stat_history(self, key):
"""
Returns:
list: all history of a stat.
"""
ret = []
for h in self.stat_history:
v = h.get(key, None)
......@@ -97,13 +107,14 @@ class StatHolder(object):
class StatPrinter(Callback):
"""
Control what stats to print.
A callback to control what stats to print. Print everything by default.
"""
def __init__(self, print_tag=None):
"""
:param print_tag: a list of regex to match scalar summary to print.
If None, will print all scalar tags
Args:
print_tag: a list of stat names to print.
If None, will print all scalar tags.
"""
self.print_tag = print_tag
......@@ -125,15 +136,25 @@ class StatPrinter(Callback):
class SendStat(Callback):
"""
Execute a command with some specific stats.
For example, send the stats to your phone through pushbullet:
This is useful for, e.g. building a custom statistics monitor.
"""
def __init__(self, command, stats):
"""
Args:
command(str): a command to execute. Use format string with stat
names as keys.
stats(list or str): stat name(s) to use.
Example:
Send the stats to your phone through pushbullet:
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes \
-d type=note -d title="validation error" \
.. code-block:: python
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes \\
-d type=note -d title="validation error" \\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
def __init__(self, command, stats):
self.command = command
if not isinstance(stats, list):
stats = [stats]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment