Commit b5f8c73a authored by Yuxin Wu

sphinx doc for callbacks

parent b5acbf3a
...@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained models for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
They're provided in the format of a numpy dictionary, so it should be very easy to port into other applications.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation error.
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
......
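Since the weights ship as a plain numpy dictionary, porting needs no TensorFlow at all; a minimal loading sketch (the file name is an assumption, and newer numpy versions may require `allow_pickle=True`):

```python
import numpy as np

# Hypothetical file name; use whatever you downloaded from the drive link above.
params = np.load('ResNet18-bwn-4bit.npy', allow_pickle=True).item()  # dict: name -> ndarray
for name, value in params.items():
    print(name, value.shape)
```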
...@@ -27,7 +27,8 @@ class Callback(object):
        Called before finalizing the graph.
        Use this callback to setup some ops used in the callback.

        Args:
            trainer(Trainer): the trainer which calls the callback
        """
        self.trainer = trainer
        self.graph = tf.get_default_graph()
...@@ -59,7 +60,7 @@ class Callback(object):
        """
        Triggered after every epoch.
        In this function, ``self.epoch_num`` would be the number of epochs finished.
        """
        self.epoch_num += 1
        self._trigger_epoch()
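As a hedged illustration of the subclassing pattern documented above (``EpochTimer`` is a hypothetical callback, and ``logger`` is assumed to be tensorpack's logger):

```python
import time

class EpochTimer(Callback):
    def _before_train(self):
        self._start = time.time()

    def _trigger_epoch(self):
        # self.epoch_num is maintained by the Callback base class
        logger.info("Epoch {} finished after {:.1f}s.".format(
            self.epoch_num, time.time() - self._start))
```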
...@@ -72,8 +73,15 @@ class Callback(object):
class ProxyCallback(Callback):
    """ A callback which proxies all methods to another callback.
        It's useful as a base class of callbacks which decorate other callbacks.
    """

    def __init__(self, cb):
        """
        Args:
            cb(Callback): the underlying callback
        """
        self.cb = cb

    def _before_train(self):
...@@ -94,14 +102,20 @@ class ProxyCallback(Callback):
class PeriodicCallback(ProxyCallback):
    """
    Wrap a callback so that it is triggered after every ``period`` epochs.

    Doesn't work for ``trigger_step``.
    """

    def __init__(self, cb, period):
        """
        Args:
            cb(Callback): the callback to be triggered periodically
            period(int): the period

        Note:
            In ``cb``, ``self.epoch_num`` will not be the true number of
            epochs any more.
        """
        super(PeriodicCallback, self).__init__(cb)
        self.period = int(period)
......
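A hedged usage sketch: wrapping a callback this way makes it fire only on every k-th epoch (the wrapped callback and the variable name are illustrative assumptions):

```python
# Dump 'conv1/W' as images only every 3 epochs instead of every epoch.
cb = PeriodicCallback(DumpParamAsImage('conv1/W'), period=3)
```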
...@@ -11,15 +11,19 @@ __all__ = ['StartProcOrThread']
class StartProcOrThread(Callback):
    """
    Start some threads or processes before training.
    """

    def __init__(self, startable):
        """
        Args:
            startable(list): list of processes or threads which have a ``start()`` method.
                Can also be a single instance of a process or thread.
        """
        if not isinstance(startable, list):
            startable = [startable]
        self._procs_threads = startable

    def _before_train(self):
        logger.info("Starting " +
......
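A hedged usage sketch (`pump_data` and `watch_queue` are hypothetical functions; anything with a ``start()`` method works):

```python
import multiprocessing as mp
import threading

proc = mp.Process(target=pump_data)          # hypothetical worker function
thr = threading.Thread(target=watch_queue)   # hypothetical monitor function
cb = StartProcOrThread([proc, thr])          # a single instance also works
```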
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dispatcher.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']
class OutputTensorDispatcer(object):
    def __init__(self):
        self._names = []
        self._idxs = []

    def add_entry(self, names):
        v = []
        for n in names:
            tensorname = get_op_tensor_name(n)[1]
            if tensorname in self._names:
                v.append(self._names.index(tensorname))
            else:
                self._names.append(tensorname)
                v.append(len(self._names) - 1)
        self._idxs.append(v)

    def get_all_names(self):
        return self._names

    def get_idx_for_each_entry(self):
        return self._idxs
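To illustrate the intended behavior (tensor names are made up): all entries share one deduplicated name list, and each entry keeps indices into it:

```python
d = OutputTensorDispatcer()
d.add_entry(['cost:0', 'accuracy:0'])
d.add_entry(['accuracy:0', 'wrong:0'])
print(d.get_all_names())           # ['cost:0', 'accuracy:0', 'wrong:0']
print(d.get_idx_for_each_entry())  # [[0, 1], [1, 2]]
```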
...@@ -8,26 +8,27 @@ import numpy as np
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_tensor_name

__all__ = ['DumpParamAsImage']

class DumpParamAsImage(Callback):
    """
    Dump a variable to image(s) to ``logger.LOG_DIR`` after every epoch.
    """

    def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
        """
        Args:
            var_name (str): the name of the variable.
            prefix (str): the filename prefix for saved images. Defaults to the Op name.
            map_func: map the value of the variable to an image or list of
                images of shape [h, w] or [h, w, c]. If None, will use identity.
            scale (float): a multiplier on pixel values, applied after map_func. Defaults to 255.
            clip (bool): whether to clip the result to [0, 255].
        """
        op_name, self.var_name = get_op_tensor_name(var_name)
        self.func = map_func
        if prefix is None:
            self.prefix = op_name
...@@ -45,7 +46,7 @@ class DumpParamAsImage(Callback):
        val = self.trainer.sess.run(self.var)
        if self.func is not None:
            val = self.func(val)
        if isinstance(val, list) or val.ndim == 4:
            for idx, im in enumerate(val):
                self._dump_image(im, idx)
        else:
......
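A hedged usage sketch: dumping first-layer convolution kernels each epoch. The variable name and the [h, w, in, out] to [out, h, w, in] transpose are assumptions, and render sensibly only for 1- or 3-channel inputs:

```python
cb = DumpParamAsImage('conv1/W', map_func=lambda w: w.transpose(3, 0, 1, 2))
```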
...@@ -11,13 +11,19 @@ __all__ = ['RunOp']
class RunOp(Callback):
    """ Run an Op. """

    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        Args:
            setup_func: a function that returns the Op in the graph
            run_before (bool): run the Op before training
            run_epoch (bool): run the Op on every epoch trigger

        Examples:
            The `DQN Example
            <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
            uses this callback to update the target network.
        """
        self.setup_func = setup_func
        self.run_before = run_before
...@@ -25,7 +31,6 @@ class RunOp(Callback):
    def _setup_graph(self):
        self._op = self.setup_func()

    def _before_train(self):
        if self.run_before:
......
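A hedged sketch in the spirit of the DQN example linked above (`make_copy_ops` is a hypothetical helper returning weight-copy ops):

```python
import tensorflow as tf

def update_target():
    # Group the copy ops into one Op that RunOp can run each epoch.
    return tf.group(*make_copy_ops(), name='update_target_network')

cb = RunOp(update_target, run_before=True, run_epoch=True)
```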
...@@ -44,12 +44,14 @@ class CallbackTimeLogger(object):
class Callbacks(Callback):
    """
    A container to hold all callbacks, and execute them in the right order
    (e.g. :class:`StatPrinter` will be executed last).
    """

    def __init__(self, cbs):
        """
        Args:
            cbs(list): a list of :class:`Callback` instances.
        """
        # check type
        for cb in cbs:
......
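A hedged sketch of a typical container; per the docstring, :class:`StatPrinter` is reordered to run last regardless of its position in the list:

```python
cbs = Callbacks([
    ModelSaver(),
    ScheduledHyperParamSetter('learning_rate', [(0, 1e-2), (30, 1e-3)]),
    StatPrinter(),
])
```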
...@@ -12,12 +12,13 @@ from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name

__all__ = ['ScalarStats', 'Inferencer',
           'ClassificationError', 'BinaryClassificationStats']

@six.add_metaclass(ABCMeta)
class Inferencer(object):
    """ Base class of Inferencer. To be used with :class:`InferenceRunner`. """

    def before_inference(self):
        """
...@@ -30,7 +31,11 @@ class Inferencer(object):
    def datapoint(self, output):
        """
        Called after each new datapoint has finished forward inference.

        Args:
            output(list): list of outputs this inferencer needs. Has the same
                length as ``self.get_output_tensors()``.
        """
        self._datapoint(output)
...@@ -41,8 +46,8 @@ class Inferencer(object):
    def after_inference(self):
        """
        Called after a round of inference ends.
        Returns a dict of statistics which will be logged by the :class:`InferenceRunner`.
        The inferencer needs to handle other types of logging by itself, if there are any.
        """
        return self._after_inference()
...@@ -51,7 +56,7 @@ class Inferencer(object):
    def get_output_tensors(self):
        """
        Return a list of tensor names this inferencer needs.
        """
        return self._get_output_tensors()
...@@ -62,15 +67,16 @@ class Inferencer(object):
class ScalarStats(Inferencer):
    """
    Statistics of some scalar tensor.
    The value will be averaged over all given datapoints.
    """

    def __init__(self, names_to_print, prefix='validation'):
        """
        Args:
            names_to_print(list or str): list of names or just one name. The
                corresponding tensors have to be scalar.
            prefix(str): a prefix for logging
        """
        if not isinstance(names_to_print, list):
            self.names = [names_to_print]
...@@ -85,6 +91,8 @@ class ScalarStats(Inferencer):
        self.stats = []

    def _datapoint(self, output):
        for o in output:
            assert isinstance(o, (float, np.float32)), type(o)
        self.stats.append(output)

    def _after_inference(self):
...@@ -101,24 +109,27 @@ class ScalarStats(Inferencer):
class ClassificationError(Inferencer):
    """
    Compute classification error in batch mode, from a ``wrong`` tensor.

    The ``wrong`` tensor is supposed to be a binary vector containing
    whether each sample in the batch is *incorrectly* classified.
    You can use ``tf.nn.in_top_k`` to produce this vector.

    This Inferencer produces the "true" error,
    taking into account the fact that batches might not have the same size in
    testing (because the size of the test set might not be a multiple of batch size).
    Therefore the result can be different from averaging the error rate of each batch.
    """

    def __init__(self, wrong_tensor_name='incorrect_vector', summary_name='val_error'):
        """
        Args:
            wrong_tensor_name(str): name of the ``wrong`` tensor.
                The default is the same as the default output name of
                :meth:`prediction_incorrect`.
            summary_name(str): the name for logging.
        """
        self.wrong_var_name = wrong_tensor_name
        self.summary_name = summary_name

    def _get_output_tensors(self):
...@@ -144,21 +155,23 @@ class ClassificationError(Inferencer):
class BinaryClassificationStats(Inferencer):
    """
    Compute precision / recall in binary classification, given the
    prediction vector and the label vector.
    """

    def __init__(self, pred_tensor_name, label_tensor_name, summary_prefix='val'):
        """
        Args:
            pred_tensor_name(str): name of the 0/1 prediction tensor.
            label_tensor_name(str): name of the 0/1 label tensor.
        """
        self.pred_tensor_name = pred_tensor_name
        self.label_tensor_name = label_tensor_name
        self.prefix = summary_prefix

    def _get_output_tensors(self):
        return [self.pred_tensor_name, self.label_tensor_name]

    def _before_inference(self):
        self.stat = BinaryStatistics()
......
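A hedged usage sketch, assuming the model defines 0/1 tensors named 'prediction' and 'label':

```python
inf = BinaryClassificationStats('prediction', 'label', summary_prefix='val')
```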
...@@ -11,13 +11,36 @@ from six.moves import zip, range
from ..dataflow import DataFlow
from .base import Callback
from .inference import Inferencer
from ..utils import logger, get_tqdm
from ..tfutils.common import get_op_tensor_name
from ..train.input_data import FeedfreeInput

__all__ = ['InferenceRunner']
class OutputTensorDispatcer(object):
    def __init__(self):
        self._names = []
        self._idxs = []

    def add_entry(self, names):
        v = []
        for n in names:
            tensorname = get_op_tensor_name(n)[1]
            if tensorname in self._names:
                v.append(self._names.index(tensorname))
            else:
                self._names.append(tensorname)
                v.append(len(self._names) - 1)
        self._idxs.append(v)

    def get_all_names(self):
        return self._names

    def get_idx_for_each_entry(self):
        return self._idxs
def summary_inferencer(trainer, infs):
    for inf in infs:
        ret = inf.after_inference()
...@@ -32,17 +55,19 @@ def summary_inferencer(trainer, infs):
class InferenceRunner(Callback):
    """
    A callback that runs a list of :class:`Inferencer` on some
    :class:`DataFlow`.
    """

    _IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, ds, infs, input_tensors=None):
        """
        Args:
            ds (DataFlow): the DataFlow to run inferencer on.
            infs (list): a list of :class:`Inferencer` instances.
            input_tensors (list): list of tensors to feed the dataflow to.
                Defaults to all the input placeholders.
        """
        assert isinstance(ds, DataFlow), ds
        self.ds = ds
...@@ -78,7 +103,7 @@ class InferenceRunner(Callback):
            dispatcer.add_entry(inf.get_output_tensors())
        all_names = dispatcer.get_all_names()

        IOTensor = InferenceRunner._IOTensor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))
......
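A hedged sketch tying the pieces together: run two inferencers over a validation DataFlow after each epoch (`dataset_val` and the tensor names are assumptions):

```python
cb = InferenceRunner(dataset_val,
                     [ScalarStats('cost'),
                      ClassificationError('incorrect_vector')])
```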
...@@ -13,40 +13,59 @@ from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name

__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
           'HyperParamSetter', 'HumanHyperParamSetter',
           'ScheduledHyperParamSetter',
           'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
           ]

@six.add_metaclass(ABCMeta)
class HyperParam(object):
    """ Base class for a hyperparam. """

    def setup_graph(self):
        """ setup the graph in ``setup_graph`` callback stage, if necessary"""
        pass

    @abstractmethod
    def set_value(self, v):
        """
        Set the value of the param.

        Args:
            v: the value to be set
        """
        pass

    @abstractmethod
    def get_value(self):
        """
        Get the value of the param.
        """
        pass

    @property
    def readable_name(self):
        """ A name to display """
        return self._readable_name
class GraphVarParam(HyperParam):
    """ A variable in the graph (e.g. learning_rate) can be a hyperparam. """

    def __init__(self, name, shape=[]):
        """
        Args:
            name(str): name of the variable.
            shape(list): shape of the variable.
        """
        self.name = name
        self.shape = shape
        self._readable_name, self.var_name = get_op_var_name(name)

    def setup_graph(self):
        """ Will setup the assign operator for that variable. """
        all_vars = tf.global_variables()
        for v in all_vars:
            if v.name == self.var_name:
...@@ -60,17 +79,24 @@ class GraphVarParam(HyperParam):
        self.assign_op = self.var.assign(self.val_holder)

    def set_value(self, v):
        """ Assign the variable a new value. """
        self.assign_op.eval(feed_dict={self.val_holder: v})

    def get_value(self):
        """ Evaluate the variable. """
        return self.var.eval()

class ObjAttrParam(HyperParam):
    """ An attribute of an object can be a hyperparam. """

    def __init__(self, obj, attrname, readable_name=None):
        """
        Args:
            obj: the object
            attrname (str): the attribute
            readable_name(str): the name to display. Defaults to ``attrname``.
        """
        self.obj = obj
        self.attrname = attrname
        if readable_name is None:
...@@ -87,12 +113,14 @@ class ObjAttrParam(HyperParam):
class HyperParamSetter(Callback):
    """
    An abstract base callback to set hyperparameters in every epoch.
    """

    def __init__(self, param):
        """
        Args:
            param(HyperParam or str): if a :class:`str`, it is assumed to
                be a scalar :class:`GraphVarParam`.
        """
        # if a string, assumed to be a scalar graph variable
        if isinstance(param, six.string_types):
...@@ -106,7 +134,13 @@ class HyperParamSetter(Callback):
    def get_value_to_set(self):
        """
        Returns:
            The value to assign to the variable.

        Note:
            Subclasses will implement the abstract method
            :meth:`_get_value_to_set`, which should return a new value to
            set, or return None to do nothing.
        """
        ret = self._get_value_to_set()
        if ret is not None and ret != self.last_value:
...@@ -115,13 +149,17 @@ class HyperParamSetter(Callback):
            self.last_value = ret
        return ret

    @abstractmethod
    def _get_value_to_set(self):
        pass

    def get_current_value(self):
        """
        Returns:
            The current value of the param.
        """
        return self.param.get_value()

    def _trigger_epoch(self):
        self._set_param()
...@@ -136,14 +174,19 @@
class HumanHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameter by loading the value from a file each time it gets called.
    This is useful for manually tuning some parameters (e.g. learning_rate)
    without interrupting the training.
    """

    def __init__(self, param, file_name='hyper.txt'):
        """
        Args:
            param: same as in :class:`HyperParamSetter`.
            file_name(str): a file containing the value of the variable.
                Each line in the file is a k:v pair, where k is
                param.readable_name, and v is the value. If the pair is not found,
                the param will not be changed.
        """
        super(HumanHyperParamSetter, self).__init__(param)
        self.file_name = os.path.join(logger.LOG_DIR, file_name)
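A hedged usage sketch: while training runs, edit ``<LOG_DIR>/hyper.txt`` to contain a line like ``learning_rate:0.002`` and the new value is picked up at the next trigger:

```python
setter = HumanHyperParamSetter('learning_rate', file_name='hyper.txt')
```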
...@@ -170,15 +213,25 @@ class HumanHyperParamSetter(HyperParamSetter):
class ScheduledHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by a predefined epoch-based schedule.
    """

    def __init__(self, param, schedule, interp=None):
        """
        Args:
            param: same as in :class:`HyperParamSetter`.
            schedule(list): with the format ``[(epoch1, val1), (epoch2, val2),
                (epoch3, val3), ...]``.
                Each ``(ep, val)`` pair means to set the param
                to "val" after the ``ep``-th epoch.
                If ep == 0, the value will be set before training.
            interp: None: no interpolation. 'linear': linear interpolation.

        Example:
            .. code-block:: python

                ScheduledHyperParamSetter('learning_rate',
                                          [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)])
        """
        schedule = [(int(a), float(b)) for a, b in schedule]
        self.schedule = sorted(schedule, key=operator.itemgetter(0))
...@@ -209,10 +262,20 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class HyperParamSetterWithFunc(HyperParamSetter):
    """ Set the parameter by a function of epoch num and old value. """

    def __init__(self, param, func):
        """
        Args:
            param: same as in :class:`HyperParamSetter`.
            func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.

        Example:
            Decrease by a factor of 0.9 every two epochs:

            .. code-block:: python

                HyperParamSetterWithFunc('learning_rate',
                                         lambda e, x: x * 0.9 if e % 2 == 0 else x)
        """
        super(HyperParamSetterWithFunc, self).__init__(param)
        self.f = func
...@@ -222,22 +285,32 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class StatMonitorParamSetter(HyperParamSetter):
    """
    Change the param by monitoring the change of a statistic.
    Change when it wasn't decreasing/increasing enough.
    """

    def __init__(self, param, stat_name, value_func, threshold,
                 last_k, reverse=False):
        """
        Args:
            param: same as in :class:`HyperParamSetter`.
            stat_name (str): name of the statistics.
            value_func (float -> float): a function which returns a new value
                taking the old value.
            threshold (float): change threshold.
            last_k (int): last k epochs.
            reverse (bool): monitor increasing instead of decreasing.

        This callback will change the param by ``new_value = value_func(old_value)`` when
        ``min(stats) >= stats[0] - threshold``, where
        ``stats = [stat_name in the last k epochs]``.

        Example:
            If validation error wasn't decreasing for 5 epochs, anneal the learning rate:

            .. code-block:: python

                StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
        """
        super(StatMonitorParamSetter, self).__init__(param)
        self.stat_name = stat_name
......
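For behaviors the shipped setters don't cover, subclasses implement ``_get_value_to_set``; a hedged sketch (``WarmupLR`` and its 5-epoch linear ramp are illustrative assumptions):

```python
class WarmupLR(HyperParamSetter):
    def _get_value_to_set(self):
        # Ramp the learning rate up over the first 5 epochs, then leave it alone.
        if self.epoch_num < 5:
            return 1e-3 * (self.epoch_num + 1) / 5.0
        return None   # None means "do not change the param"
```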
...@@ -16,22 +16,19 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
    """
    Save the model to ``logger.LOG_DIR`` directory every epoch.
    """

    def __init__(self, keep_recent=10, keep_freq=0.5,
                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
        """
        Args:
            keep_recent(int): see ``tf.train.Saver`` documentation.
            keep_freq(int): see ``tf.train.Saver`` documentation.
            var_collections (str or list): the variable collection (or list of collections) to save.
        """
        self.keep_recent = keep_recent
        self.keep_freq = keep_freq
        if not isinstance(var_collections, list):
            var_collections = [var_collections]
        self.var_collections = var_collections
...@@ -87,8 +84,25 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
class MinSaver(Callback):
    """
    Separately save the model with minimum value of some statistics.
    """

    def __init__(self, monitor_stat, reverse=False, filename=None):
        """
        Args:
            monitor_stat(str): the name of the statistics.
            reverse (bool): if True, will save the maximum.
            filename (str): the name for the saved model.
                Defaults to ``min-{monitor_stat}.tfmodel``.

        Example:
            Save the model with minimum validation error to
            "min-val-error.tfmodel" under ``logger.LOG_DIR``:

            .. code-block:: python

                MinSaver('val-error')
        """
        self.monitor_stat = monitor_stat
        self.reverse = reverse
        self.filename = filename
...@@ -128,6 +142,14 @@ class MinSaver(Callback):
class MaxSaver(MinSaver):
    """
    Separately save the model with maximum value of some statistics.
    """

    def __init__(self, monitor_stat, filename=None):
        """
        Args:
            monitor_stat(str): the name of the statistics.
            filename (str): the name for the saved model.
                Defaults to ``max-{monitor_stat}.tfmodel``.
        """
        super(MaxSaver, self).__init__(monitor_stat, True, filename=filename)
...@@ -20,7 +20,8 @@ class StatHolder(object):
    def __init__(self, log_dir):
        """
        Args:
            log_dir(str): directory to save the stats.
        """
        self.set_print_tag([])
        self.blacklist_tag = set()
...@@ -38,19 +39,24 @@
    def add_stat(self, k, v):
        """
        Add a stat.
        """
        self.stat_now[k] = float(v)

    def set_print_tag(self, print_tag):
        """
        Set the names of stats to print.

        Args:
            print_tag: a collection of strings.
        """
        self.print_tag = None if print_tag is None else set(print_tag)

    def add_blacklist_tag(self, blacklist_tag):
        """ Disable printing for some tags.

        Args:
            blacklist_tag: a collection of strings.
        """
        self.blacklist_tag |= set(blacklist_tag)

    def get_stat_now(self, key):
...@@ -60,6 +66,10 @@
        return self.stat_now[key]

    def get_stat_history(self, key):
        """
        Returns:
            list: all history of a stat.
        """
        ret = []
        for h in self.stat_history:
            v = h.get(key, None)
...@@ -97,13 +107,14 @@
class StatPrinter(Callback):
    """
    A callback to control what stats to print. Print everything by default.
    """

    def __init__(self, print_tag=None):
        """
        Args:
            print_tag: a list of stat names to print.
                If None, will print all scalar tags.
        """
        self.print_tag = print_tag
...@@ -125,15 +136,25 @@
class SendStat(Callback):
    """
    Execute a command with some specific stats.
    This is useful for, e.g. building a custom statistics monitor.
    """

    def __init__(self, command, stats):
        """
        Args:
            command(str): a command to execute. Use format string with stat
                names as keys.
            stats(list or str): stat name(s) to use.

        Example:
            Send the stats to your phone through pushbullet:

            .. code-block:: python

                SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes \\
                         -d type=note -d title="validation error" \\
                         -d body={validation_error} > /dev/null 2>&1',
                         'validation_error')
        """
        self.command = command
        if not isinstance(stats, list):
            stats = [stats]
......