Commit a6f88814 authored by Yuxin Wu

docs in callbacks

parent 817bb882
@@ -10,67 +10,83 @@ from abc import abstractmethod, ABCMeta
from ..utils import *
__all__ = ['Callback', 'PeriodicCallback', 'TrainCallback', 'TestCallback']
__all__ = ['Callback', 'PeriodicCallback', 'TrainCallbackType', 'TestCallbackType']
class TrainCallback(object):
class TrainCallbackType(object):
pass
class TestCallback(object):
class TestCallbackType(object):
pass
class Callback(object):
""" Base class for all callbacks """
__metaclass__ = ABCMeta
type = TrainCallback()
""" The graph that this callback should run on.
Either TrainCallback or TestCallback
type = TrainCallbackType()
""" Determine the graph that this callback should run on.
Either `TrainCallbackType()` or `TestCallbackType()`.
Default is `TrainCallbackType()`.
"""
def before_train(self, trainer):
"""
Called before starting iterative training.
:param trainer: a :class:`train.Trainer` instance
"""
self.trainer = trainer
self.graph = tf.get_default_graph()
self.epoch_num = self.trainer.config.starting_epoch
self._before_train()
def _before_train(self):
"""
Called before starting iterative training
"""
pass
def after_train(self):
"""
Called after training.
"""
self._after_train()
def _after_train(self):
"""
Called after training
"""
pass
def trigger_step(self):
"""
Callback to be triggered after every step (every backpropagation).
Could be useful to apply some tricks on parameters (clipping, low-rank, etc.).
"""
@property
def global_step(self):
"""
Access the global step value of this training.
"""
return self.trainer.global_step
def trigger_epoch(self):
"""
epoch_num is the number of epoch finished.
Triggered after every epoch.
In this function, self.epoch_num is the number of epochs finished.
"""
self._trigger_epoch()
self.epoch_num += 1
def _trigger_epoch(self):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
pass
class PeriodicCallback(Callback):
"""
A callback to be triggered after every `period` epochs.
"""
def __init__(self, period):
self.period = period
"""
:param period: int
"""
self.period = int(period)
def _trigger_epoch(self):
if self.epoch_num % self.period == 0:
......
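For reference, a minimal sketch of a concrete callback against the API above (the class name is hypothetical): subclass `Callback`, optionally set `type`, and override the private hooks.

class EpochLogger(Callback):
    """ A toy callback: print the number of finished epochs after every epoch. """
    type = TrainCallbackType()   # run on the training graph (the default)

    def _trigger_epoch(self):
        # self.epoch_num is initialized in before_train() and
        # incremented by trigger_epoch() after this hook runs
        print("Finished epoch {}.".format(self.epoch_num))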
@@ -12,7 +12,15 @@ from ..utils import *
__all__ = ['PeriodicSaver']
class PeriodicSaver(PeriodicCallback):
"""
Save the model to the logger directory.
"""
def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
"""
:param period: save the model every `period` epochs.
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
"""
super(PeriodicSaver, self).__init__(period)
self.keep_recent = keep_recent
self.keep_freq = keep_freq
......
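A hedged usage sketch (not part of the diff): assuming `keep_recent` and `keep_freq` are forwarded to `tf.train.Saver`'s `max_to_keep` and `keep_checkpoint_every_n_hours`, a saver that checkpoints every two epochs would look like:

# keep the 10 most recent checkpoints, plus one every half hour of training
saver_cb = PeriodicSaver(period=2, keep_recent=10, keep_freq=0.5)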
@@ -9,21 +9,30 @@ import numpy as np
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
"""
Dump a variable to image(s) after every epoch.
"""
def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
"""
map_func: map the value of the variable to an image or list of images, default to identity
images should have shape [h, w] or [h, w, c].
scale: a multiplier on pixel values, applied after map_func. default to 255
clip: clip the result to [0, 255]
:param var_name: the name of the variable.
:param prefix: the filename prefix for saved images. Default is the op name.
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, use the identity.
:param scale: a multiplier on pixel values, applied after map_func. Default is 255.
:param clip: whether to clip the result to [0, 255]
"""
self.var_name = var_name
op_name, self.var_name = get_op_var_name(var_name)
self.func = map_func
if prefix is None:
self.prefix = self.var_name
self.prefix = op_name
else:
self.prefix = prefix
self.log_dir = logger.LOG_DIR
......
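A usage sketch for the new signature (variable name and `map_func` are hypothetical). Since `get_op_var_name` resolves the name, passing either the op name or the variable name should work:

# dump a probability map as an 8-bit image after every epoch
dump_cb = DumpParamAsImage(
    'viz/prob_map',              # hypothetical variable name
    map_func=lambda v: v[0],     # e.g. take the first image of a batch
    scale=255, clip=True)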
@@ -6,7 +6,7 @@ import tensorflow as tf
from contextlib import contextmanager
import time
from .base import Callback, TrainCallback, TestCallback
from .base import Callback, TrainCallbackType, TestCallbackType
from .summary import *
from ..utils import *
@@ -91,11 +91,17 @@ class TestCallbackContext(object):
yield
class Callbacks(Callback):
"""
A container to hold all callbacks and execute them in the right order and in the proper session.
"""
def __init__(self, cbs):
"""
:param cbs: a list of `Callback` instances
"""
# check type
for cb in cbs:
assert isinstance(cb, Callback), cb.__class__
if not isinstance(cb.type, (TrainCallback, TestCallback)):
if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)):
raise ValueError(
"Unknown callback running graph {}!".format(str(cb.type)))
@@ -104,7 +110,7 @@ class Callbacks(Callback):
def _before_train(self):
for cb in self.cbs:
if isinstance(cb.type, TrainCallback):
if isinstance(cb.type, TrainCallbackType):
cb.before_train(self.trainer)
else:
with self.test_callback_context.before_train_context(self.trainer):
@@ -116,7 +122,7 @@ class Callbacks(Callback):
def trigger_step(self):
for cb in self.cbs:
if isinstance(cb.type, TrainCallback):
if isinstance(cb.type, TrainCallbackType):
cb.trigger_step()
# test callbacks don't have trigger_step
@@ -125,7 +131,7 @@ class Callbacks(Callback):
test_sess_restored = False
for cb in self.cbs:
if isinstance(cb.type, TrainCallback):
if isinstance(cb.type, TrainCallbackType):
with tm.timed_callback(type(cb).__name__):
cb.trigger_epoch()
else:
......
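A sketch of assembling the container (the callback instances are hypothetical); `Callbacks` dispatches each callback to the training or test graph according to `cb.type`:

cbs = Callbacks([
    PeriodicSaver(),           # TrainCallbackType() by default
    ValidationError(val_ds),   # TestCallbackType(); val_ds is a hypothetical DataFlow
    StatPrinter(),
])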
@@ -15,10 +15,17 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter']
class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__ = ABCMeta
# TODO maybe support InputVar?
def __init__(self, var_name, shape=[]):
"""
:param var_name: name of the variable
:param shape: shape of the variable
"""
self.op_name, self.var_name = get_op_var_name(var_name)
self.shape = shape
self.last_value = None
@@ -37,6 +44,9 @@ class HyperParamSetter(Callback):
self.assign_op = self.var.assign(self.val_holder)
def get_current_value(self):
"""
:returns: the value to assign to the variable now.
"""
ret = self._get_current_value()
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} is changed to {}".format(
@@ -54,10 +64,13 @@ class HyperParamSetter(Callback):
self.assign_op.eval(feed_dict={self.val_holder:v})
class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters manually by modifying a file.
"""
def __init__(self, var_name, file_name):
"""
read value from file_name.
file_name: each line in the file is a k:v pair
:param var_name: name of the variable.
:param file_name: a file containing the value of the variable. Each line in the file is a k:v pair.
"""
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(var_name)
@@ -77,9 +90,12 @@ class HumanHyperParamSetter(HyperParamSetter):
return None
class ScheduledHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by a predefined schedule.
"""
def __init__(self, var_name, schedule):
"""
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
"""
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0))
......
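Two sketches of the setters documented above (variable name, file name, and schedule values are hypothetical). For `HumanHyperParamSetter`, the file holds `k:v` lines such as `learning_rate:0.001`:

# follow a fixed schedule of (epoch, value) pairs, sorted internally
lr_sched = ScheduledHyperParamSetter(
    'learning_rate', [(1, 1e-2), (30, 1e-3), (60, 1e-4)])

# or edit 'hyper.txt' by hand while training is running
lr_manual = HumanHyperParamSetter('learning_rate', 'hyper.txt')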
@@ -14,8 +14,14 @@ from ..utils import *
__all__ = ['StatHolder', 'StatPrinter']
class StatHolder(object):
def __init__(self, log_dir, print_tag=None):
self.set_print_tag(print_tag)
"""
A holder to keep all statistics aside from tensorflow events.
"""
def __init__(self, log_dir):
"""
:param log_dir: directory to save the stats.
"""
self.set_print_tag([])
self.stat_now = {}
self.log_dir = log_dir
@@ -28,12 +34,23 @@ class StatHolder(object):
self.stat_history = []
def add_stat(self, k, v):
"""
Add a stat.
:param k: name
:param v: value
"""
self.stat_now[k] = v
def set_print_tag(self, print_tag):
"""
Set the names of the stats to print.
"""
self.print_tag = None if print_tag is None else set(print_tag)
def finalize(self):
"""
Called after finishing adding stats. Will print and write stats to disk.
"""
self._print_stat()
self.stat_history.append(self.stat_now)
self.stat_now = {}
@@ -51,9 +68,13 @@ class StatHolder(object):
os.rename(tmp_filename, self.filename)
class StatPrinter(Callback):
"""
Control what stats to print.
"""
def __init__(self, print_tag=None):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
:param print_tag: a list of regexes to match scalar summaries to print.
If None, will print all scalar tags.
"""
self.print_tag = print_tag
......
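A sketch of the two `print_tag` modes (the patterns below are hypothetical):

print_all = StatPrinter()                              # None: print every scalar tag
print_some = StatPrinter(print_tag=['cost', 'val.*'])  # regexes to match against tags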
@@ -10,16 +10,22 @@ from six.moves import zip
from ..utils import *
from ..utils.stat import *
from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallback
from .base import PeriodicCallback, Callback, TestCallbackType
__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(PeriodicCallback):
type = TestCallback()
"""
Base class for validation callbacks.
"""
type = TestCallbackType()
def __init__(self, ds, prefix, period=1):
"""
:param ds: validation dataset. Must be a `DataFlow` instance.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super(ValidationCallback, self).__init__(period)
self.ds = ds
self.prefix = prefix
@@ -29,6 +35,9 @@ class ValidationCallback(PeriodicCallback):
self._find_output_vars()
def get_tensor(self, name):
"""
Get a tensor from the graph by name.
"""
return self.graph.get_tensor_by_name(name)
@abstractmethod
@@ -63,6 +72,12 @@ class ValidationStatPrinter(ValidationCallback):
The result of the given Op must be a scalar, and will be averaged for all batches in the validation set.
"""
def __init__(self, ds, names_to_print, prefix='validation', period=1):
"""
:param ds: validation dataset. Must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super(ValidationStatPrinter, self).__init__(ds, prefix, period)
self.names = names_to_print
@@ -88,9 +103,9 @@ class ValidationStatPrinter(ValidationCallback):
class ValidationError(ValidationCallback):
"""
Validate the accuracy from a 'wrong' variable
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
Validate the accuracy from a `wrong` variable.
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
This callback produces the "true" error,
taking into account the fact that batches might not have the same size in
@@ -100,6 +115,10 @@ class ValidationError(ValidationCallback):
def __init__(self, ds, prefix='validation',
period=1,
wrong_var_name='wrong:0'):
"""
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super(ValidationError, self).__init__(ds, prefix, period)
self.wrong_var_name = wrong_var_name
......
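The "true" error mentioned above is the ratio of the summed `wrong` counts to the total number of validation samples, rather than an unweighted mean of per-batch error rates, so a smaller final batch is weighted correctly. A toy illustration with hypothetical numbers:

wrongs = [12, 9, 3]     # `wrong` value per batch
sizes = [128, 128, 50]  # batch sizes; the last batch is smaller
true_error = float(sum(wrongs)) / sum(sizes)
# != the unweighted mean of [12/128., 9/128., 3/50.]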
@@ -15,6 +15,8 @@ from ..base import DataFlow
__all__ = ['SVHNDigit']
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class SVHNDigit(DataFlow):
"""
SVHN Cropped Digit Dataset
@@ -25,7 +27,7 @@ class SVHNDigit(DataFlow):
def __init__(self, name, data_dir=None, shuffle=True):
"""
name: 'train', 'test', or 'extra'
data_dir: a directory containing {train,test,extra}_32x32.mat
data_dir: a directory containing the original {train,test,extra}_32x32.mat
"""
self.shuffle = shuffle
self.rng = get_rng(self)
@@ -40,8 +42,7 @@ class SVHNDigit(DataFlow):
assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \
"File {} not found! Please download it from \
http://ufldl.stanford.edu/housenumbers/".format(filename)
"File {} not found! Please download it from {}.".format(filename, SVHN_URL)
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
self.X = data['X'].transpose(3,0,1,2)
......
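A usage sketch (the data directory is hypothetical); the loader expects the original `{train,test,extra}_32x32.mat` files from `SVHN_URL` under `data_dir`:

ds = SVHNDigit('train', data_dir='/data/svhn', shuffle=True)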