Commit a6f88814 authored by Yuxin Wu

docs in callbacks

parent 817bb882
@@ -10,67 +10,83 @@ from abc import abstractmethod, ABCMeta
 from ..utils import *

-__all__ = ['Callback', 'PeriodicCallback', 'TrainCallback', 'TestCallback']
+__all__ = ['Callback', 'PeriodicCallback', 'TrainCallbackType', 'TestCallbackType']

-class TrainCallback(object):
+class TrainCallbackType(object):
     pass

-class TestCallback(object):
+class TestCallbackType(object):
     pass

 class Callback(object):
+    """ Base class for all callbacks """
     __metaclass__ = ABCMeta

-    type = TrainCallback()
-    """ The graph that this callback should run on.
-        Either TrainCallback or TestCallback
+    type = TrainCallbackType()
+    """ Determine the graph that this callback should run on.
+        Either `TrainCallbackType()` or `TestCallbackType()`.
+        Default is `TrainCallbackType()`.
     """

     def before_train(self, trainer):
+        """
+        Called before starting iterative training.
+
+        :param trainer: a :class:`train.Trainer` instance
+        """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.epoch_num = self.trainer.config.starting_epoch
         self._before_train()

     def _before_train(self):
-        """
-        Called before starting iterative training
-        """
+        pass

     def after_train(self):
+        """
+        Called after training.
+        """
         self._after_train()

     def _after_train(self):
-        """
-        Called after training
-        """
+        pass

     def trigger_step(self):
         """
         Callback to be triggered after every step (every backpropagation).
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc.)
         """

     @property
     def global_step(self):
+        """
+        Access the global step value of this training.
+        """
         return self.trainer.global_step

     def trigger_epoch(self):
         """
-        epoch_num is the number of epoch finished.
+        Triggered after every epoch.
+        In this function, self.epoch_num is the number of epochs finished.
         """
         self._trigger_epoch()
         self.epoch_num += 1

     def _trigger_epoch(self):
-        """
-        Callback to be triggered after every epoch (full iteration of input dataset)
-        """
+        pass

 class PeriodicCallback(Callback):
+    """
+    A callback to be triggered after every `period` epochs.
+    """
     def __init__(self, period):
-        self.period = period
+        """
+        :param period: int
+        """
+        self.period = int(period)

     def _trigger_epoch(self):
         if self.epoch_num % self.period == 0:
...
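For context, a minimal sketch of how a subclass might use the hooks added above. The import path and the print statements are assumptions for illustration, not part of this commit:

# Minimal sketch of a custom callback built on the hooks above.
# The import path is an assumption for illustration.
from tensorpack.callbacks.base import Callback

class EpochLogger(Callback):
    def _before_train(self):
        # before_train() has already set self.trainer, self.graph and
        # self.epoch_num before this hook runs.
        print("Starting at epoch {}".format(self.epoch_num))

    def _trigger_epoch(self):
        # Runs once per finished epoch; trigger_epoch() increments
        # self.epoch_num after this hook returns.
        print("Finished epoch {}, global step {}".format(
            self.epoch_num, self.global_step))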
@@ -12,7 +12,15 @@ from ..utils import *
 __all__ = ['PeriodicSaver']

 class PeriodicSaver(PeriodicCallback):
+    """
+    Save the model to the logger directory.
+    """
     def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
+        """
+        :param period: number of epochs between saving models.
+        :param keep_recent: see `tf.train.Saver` documentation.
+        :param keep_freq: see `tf.train.Saver` documentation.
+        """
         super(PeriodicSaver, self).__init__(period)
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
...
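A hypothetical usage of this constructor, based only on the parameters documented above (the mapping of keep_recent/keep_freq onto tf.train.Saver options is an assumption):

# Hypothetical usage: save a checkpoint every 2 epochs, keeping the
# 10 most recent checkpoints (presumably tf.train.Saver's max_to_keep)
# plus one every half hour (presumably keep_checkpoint_every_n_hours).
saver_cb = PeriodicSaver(period=2, keep_recent=10, keep_freq=0.5)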
@@ -9,21 +9,30 @@ import numpy as np
 from .base import Callback
 from ..utils import logger
+from ..tfutils import get_op_var_name

 __all__ = ['DumpParamAsImage']

 class DumpParamAsImage(Callback):
+    """
+    Dump a variable to image(s) after every epoch.
+    """
     def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
         """
-        map_func: map the value of the variable to an image or list of images, default to identity
-            images should have shape [h, w] or [h, w, c].
-        scale: a multiplier on pixel values, applied after map_func. default to 255
-        clip: clip the result to [0, 255]
+        :param var_name: the name of the variable.
+        :param prefix: the filename prefix for saved images. Default is the op name.
+        :param map_func: map the value of the variable to an image or list of
+            images of shape [h, w] or [h, w, c]. If None, use identity.
+        :param scale: a multiplier on pixel values, applied after map_func. Default is 255.
+        :param clip: whether to clip the result to [0, 255].
         """
-        self.var_name = var_name
+        op_name, self.var_name = get_op_var_name(var_name)
         self.func = map_func
         if prefix is None:
-            self.prefix = self.var_name
+            self.prefix = op_name
         else:
             self.prefix = prefix
         self.log_dir = logger.LOG_DIR
...
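A hypothetical usage of the parameters documented above; the variable name 'conv1/W:0' and the normalization are made up for illustration:

# map_func runs first, then the result is multiplied by `scale`
# and optionally clipped to [0, 255].
dump_cb = DumpParamAsImage(
    'conv1/W:0',                                             # name assumed
    map_func=lambda w: (w - w.min()) / (w.max() - w.min()),  # normalize to [0, 1]
    scale=255,
    clip=True)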
@@ -6,7 +6,7 @@ import tensorflow as tf
 from contextlib import contextmanager
 import time

-from .base import Callback, TrainCallback, TestCallback
+from .base import Callback, TrainCallbackType, TestCallbackType
 from .summary import *
 from ..utils import *
@@ -91,11 +91,17 @@ class TestCallbackContext(object):
         yield

 class Callbacks(Callback):
+    """
+    A container to hold all callbacks, and execute them in the right order and proper session.
+    """
     def __init__(self, cbs):
+        """
+        :param cbs: a list of `Callback` instances
+        """
         # check type
         for cb in cbs:
             assert isinstance(cb, Callback), cb.__class__
-            if not isinstance(cb.type, (TrainCallback, TestCallback)):
+            if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)):
                 raise ValueError(
                     "Unknown callback running graph {}!".format(str(cb.type)))
@@ -104,7 +110,7 @@ class Callbacks(Callback):
     def _before_train(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallback):
+            if isinstance(cb.type, TrainCallbackType):
                 cb.before_train(self.trainer)
             else:
                 with self.test_callback_context.before_train_context(self.trainer):
@@ -116,7 +122,7 @@ class Callbacks(Callback):
     def trigger_step(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallback):
+            if isinstance(cb.type, TrainCallbackType):
                 cb.trigger_step()
             # test callbacks don't have trigger_step
@@ -125,7 +131,7 @@ class Callbacks(Callback):
         test_sess_restored = False
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallback):
+            if isinstance(cb.type, TrainCallbackType):
                 with tm.timed_callback(type(cb).__name__):
                     cb.trigger_epoch()
             else:
...
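A sketch of how the container might be assembled; the class names come from this commit, but `val_ds` and the overall training setup are assumed:

callbacks = Callbacks([
    PeriodicSaver(period=1),   # type defaults to TrainCallbackType()
    ValidationError(val_ds),   # type is TestCallbackType(); val_ds is assumed defined
    StatPrinter(),
])
# The container dispatches on cb.type: train callbacks run directly in the
# training graph, test callbacks run inside the test session/graph context.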
@@ -15,10 +15,17 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
            'ScheduledHyperParamSetter']

 class HyperParamSetter(Callback):
+    """
+    Base class to set hyperparameters after every epoch.
+    """
     __metaclass__ = ABCMeta

     # TODO maybe support InputVar?
     def __init__(self, var_name, shape=[]):
+        """
+        :param var_name: name of the variable
+        :param shape: shape of the variable
+        """
         self.op_name, self.var_name = get_op_var_name(var_name)
         self.shape = shape
         self.last_value = None
@@ -37,6 +44,9 @@ class HyperParamSetter(Callback):
         self.assign_op = self.var.assign(self.val_holder)

     def get_current_value(self):
+        """
+        :returns: the value to assign to the variable now.
+        """
         ret = self._get_current_value()
         if ret is not None and ret != self.last_value:
             logger.info("{} at epoch {} is changed to {}".format(
@@ -54,10 +64,13 @@ class HyperParamSetter(Callback):
         self.assign_op.eval(feed_dict={self.val_holder: v})

 class HumanHyperParamSetter(HyperParamSetter):
+    """
+    Set hyperparameters manually by modifying a file.
+    """
     def __init__(self, var_name, file_name):
         """
-        read value from file_name.
-        file_name: each line in the file is a k:v pair
+        :param var_name: name of the variable.
+        :param file_name: a file containing the value of the variable. Each line in the file is a k:v pair.
         """
         self.file_name = file_name
         super(HumanHyperParamSetter, self).__init__(var_name)
@@ -77,9 +90,12 @@ class HumanHyperParamSetter(HyperParamSetter):
         return None

 class ScheduledHyperParamSetter(HyperParamSetter):
+    """
+    Set hyperparameters by a predefined schedule.
+    """
     def __init__(self, var_name, schedule):
         """
-        schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
+        :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
         """
         schedule = [(int(a), float(b)) for a, b in schedule]
         self.schedule = sorted(schedule, key=operator.itemgetter(0))
...
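A hypothetical schedule using the constructor documented above; the variable name 'learning_rate' is an assumption:

# Set 'learning_rate' to 1e-2 at epoch 1, then decay it at epochs 30 and 60.
# Pairs may be given in any order; the constructor sorts them by epoch.
lr_setter = ScheduledHyperParamSetter(
    'learning_rate',
    [(1, 1e-2), (30, 1e-3), (60, 1e-4)])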
@@ -14,8 +14,14 @@ from ..utils import *
 __all__ = ['StatHolder', 'StatPrinter']

 class StatHolder(object):
-    def __init__(self, log_dir, print_tag=None):
-        self.set_print_tag(print_tag)
+    """
+    A holder to keep all statistics aside from tensorflow events.
+    """
+    def __init__(self, log_dir):
+        """
+        :param log_dir: directory to save the stats.
+        """
+        self.set_print_tag([])
         self.stat_now = {}
         self.log_dir = log_dir
@@ -28,12 +34,23 @@ class StatHolder(object):
         self.stat_history = []

     def add_stat(self, k, v):
+        """
+        Add a stat.
+
+        :param k: name
+        :param v: value
+        """
         self.stat_now[k] = v

     def set_print_tag(self, print_tag):
+        """
+        Set the names of stats to print.
+        """
         self.print_tag = None if print_tag is None else set(print_tag)

     def finalize(self):
+        """
+        Called after finishing adding stats. Will print and write stats to disk.
+        """
         self._print_stat()
         self.stat_history.append(self.stat_now)
         self.stat_now = {}
@@ -51,9 +68,13 @@ class StatHolder(object):
         os.rename(tmp_filename, self.filename)

 class StatPrinter(Callback):
+    """
+    Control which stats to print.
+    """
     def __init__(self, print_tag=None):
         """
-        print_tag : a list of regex to match scalar summary to print
-        if None, will print all scalar tags
+        :param print_tag: a list of regexes to match scalar summaries to print.
+            If None, will print all scalar tags.
         """
         self.print_tag = print_tag
...
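A sketch of the StatHolder flow implied by the API above; the log directory and stat names are assumptions:

holder = StatHolder('/tmp/train_log')                 # directory assumed
holder.set_print_tag(['train_error', 'cost'])         # only print these stats
holder.add_stat('train_error', 0.13)
holder.add_stat('cost', 0.42)
holder.finalize()  # prints the tagged stats, archives this epoch's dict,
                   # and writes the history to disk under log_dir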
@@ -10,16 +10,22 @@ from six.moves import zip
 from ..utils import *
 from ..utils.stat import *
 from ..tfutils.summary import *
-from .base import PeriodicCallback, Callback, TestCallback
+from .base import PeriodicCallback, Callback, TestCallbackType

 __all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']

 class ValidationCallback(PeriodicCallback):
-    type = TestCallback()
     """
     Base class for validation callbacks.
     """
+    type = TestCallbackType()
+
     def __init__(self, ds, prefix, period=1):
+        """
+        :param ds: validation dataset. Must be a `DataFlow` instance.
+        :param prefix: name to use for this validation.
+        :param period: period to perform validation.
+        """
         super(ValidationCallback, self).__init__(period)
         self.ds = ds
         self.prefix = prefix
@@ -29,6 +35,9 @@ class ValidationCallback(PeriodicCallback):
         self._find_output_vars()

     def get_tensor(self, name):
+        """
+        Get a tensor from the graph.
+        """
         return self.graph.get_tensor_by_name(name)

     @abstractmethod
@@ -63,6 +72,12 @@ class ValidationStatPrinter(ValidationCallback):
     The result of the given Op must be a scalar, and will be averaged over all batches in the validation set.
     """
     def __init__(self, ds, names_to_print, prefix='validation', period=1):
+        """
+        :param ds: validation dataset. Must be a `DataFlow` instance.
+        :param names_to_print: names of variables to print.
+        :param prefix: name to use for this validation.
+        :param period: period to perform validation.
+        """
         super(ValidationStatPrinter, self).__init__(ds, prefix, period)
         self.names = names_to_print
@@ -88,9 +103,9 @@ class ValidationStatPrinter(ValidationCallback):
 class ValidationError(ValidationCallback):
     """
-    Validate the accuracy from a 'wrong' variable
-    wrong_var: integer, number of failed samples in this batch
-    ds: batched dataset
+    Validate the accuracy from a `wrong` variable.
+
+    The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.

     This callback produces the "true" error,
     taking account of the fact that batches might not have the same size in
@@ -100,6 +115,10 @@ class ValidationError(ValidationCallback):
     def __init__(self, ds, prefix='validation',
                  period=1,
                  wrong_var_name='wrong:0'):
+        """
+        :param ds: a batched `DataFlow` instance
+        :param wrong_var_name: name of the `wrong` variable
+        """
         super(ValidationError, self).__init__(ds, prefix, period)
         self.wrong_var_name = wrong_var_name
...
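Hypothetical wiring for the two concrete validation callbacks above; `val_ds` stands for some batched DataFlow over the validation set and the variable names are assumptions:

val_cost = ValidationStatPrinter(val_ds, ['cost:0'],
                                 prefix='validation', period=1)
val_err = ValidationError(val_ds, prefix='validation',
                          period=1, wrong_var_name='wrong:0')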
@@ -15,6 +15,8 @@ from ..base import DataFlow
 __all__ = ['SVHNDigit']

+SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
+
 class SVHNDigit(DataFlow):
     """
     SVHN Cropped Digit Dataset
@@ -25,7 +27,7 @@ class SVHNDigit(DataFlow):
     def __init__(self, name, data_dir=None, shuffle=True):
         """
         name: 'train', 'test', or 'extra'
-        data_dir: a directory containing {train,test,extra}_32x32.mat
+        data_dir: a directory containing the original {train,test,extra}_32x32.mat
         """
         self.shuffle = shuffle
         self.rng = get_rng(self)
@@ -40,8 +42,7 @@ class SVHNDigit(DataFlow):
         assert name in ['train', 'test', 'extra'], name
         filename = os.path.join(data_dir, name + '_32x32.mat')
         assert os.path.isfile(filename), \
-            "File {} not found! Please download it from \
-http://ufldl.stanford.edu/housenumbers/".format(filename)
+            "File {} not found! Please download it from {}.".format(filename, SVHN_URL)
         logger.info("Loading {} ...".format(filename))
         data = scipy.io.loadmat(filename)
         self.X = data['X'].transpose(3,0,1,2)
...
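A usage sketch for the constructor documented above; the data directory path is an assumption:

# data_dir must already contain the original {train,test,extra}_32x32.mat
# files downloaded from SVHN_URL.
train_ds = SVHNDigit('train', data_dir='/path/to/svhn_data', shuffle=True)
test_ds = SVHNDigit('test', data_dir='/path/to/svhn_data', shuffle=False)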