Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b5f8c73a
Commit
b5f8c73a
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
sphinx doc for callbacks
parent
b5acbf3a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
312 additions
and
163 deletions
+312
-163
examples/DoReFa-Net/README.md
examples/DoReFa-Net/README.md
+1
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+20
-6
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+10
-6
tensorpack/callbacks/dispatcher.py
tensorpack/callbacks/dispatcher.py
+0
-31
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+11
-10
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+10
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+4
-2
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+41
-28
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+33
-8
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+112
-39
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+35
-13
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+35
-14
No files found.
examples/DoReFa-Net/README.md
View file @
b5f8c73a
...
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The
binary-weight 4-bit-activation ResNet-18
model has 59.2% top-1 validation error.
The
__binary-weight 4-bit-activation ResNet-18__
model has 59.2% top-1 validation error.
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
...
...
tensorpack/callbacks/base.py
View file @
b5f8c73a
...
...
@@ -27,7 +27,8 @@ class Callback(object):
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
:param trainer: :class:`train.Trainer` instance
Args:
trainer(Trainer): the trainer which calls the callback
"""
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
...
...
@@ -59,7 +60,7 @@ class Callback(object):
"""
Triggered after every epoch.
In this function,
self.epoch_num
would be the number of epoch finished.
In this function,
``self.epoch_num``
would be the number of epoch finished.
"""
self
.
epoch_num
+=
1
self
.
_trigger_epoch
()
...
...
@@ -72,8 +73,15 @@ class Callback(object):
class
ProxyCallback
(
Callback
):
""" A callback which proxies all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
"""
def
__init__
(
self
,
cb
):
"""
Args:
cb(Callback): the underlying callback
"""
self
.
cb
=
cb
def
_before_train
(
self
):
...
...
@@ -94,14 +102,20 @@ class ProxyCallback(Callback):
class
PeriodicCallback
(
ProxyCallback
):
"""
A callback to be triggered after every `period` epochs.
Doesn't work for trigger_step
Wrap a callback so that it is triggered after every ``period`` epochs.
Doesn't work for ``trigger_step``.
"""
def
__init__
(
self
,
cb
,
period
):
"""
:param cb: a `Callback`
:param period: int
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
...
...
tensorpack/callbacks/concurrency.py
View file @
b5f8c73a
...
...
@@ -11,15 +11,19 @@ __all__ = ['StartProcOrThread']
class
StartProcOrThread
(
Callback
):
"""
Start some threads or processes before training.
"""
def
__init__
(
self
,
procs_threads
):
def
__init__
(
self
,
startable
):
"""
Start extra threads and processes before training
:param procs_threads: list of processes or threads
Args:
startable(list): list of processes or threads which have ``start()`` method.
Can also be a single instance of process or thread.
"""
if
not
isinstance
(
procs_threads
,
list
):
procs_threads
=
[
procs_threads
]
self
.
_procs_threads
=
procs_threads
if
not
isinstance
(
startable
,
list
):
startable
=
[
startable
]
self
.
_procs_threads
=
startable
def
_before_train
(
self
):
logger
.
info
(
"Starting "
+
...
...
tensorpack/callbacks/dispatcher.py
deleted
100644 → 0
View file @
b5acbf3a
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dispatcher.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..tfutils.common
import
get_op_tensor_name
__all__
=
[
'OutputTensorDispatcer'
]
class
OutputTensorDispatcer
(
object
):
def
__init__
(
self
):
self
.
_names
=
[]
self
.
_idxs
=
[]
def
add_entry
(
self
,
names
):
v
=
[]
for
n
in
names
:
tensorname
=
get_op_tensor_name
(
n
)[
1
]
if
tensorname
in
self
.
_names
:
v
.
append
(
self
.
_names
.
index
(
tensorname
))
else
:
self
.
_names
.
append
(
tensorname
)
v
.
append
(
len
(
self
.
_names
)
-
1
)
self
.
_idxs
.
append
(
v
)
def
get_all_names
(
self
):
return
self
.
_names
def
get_idx_for_each_entry
(
self
):
return
self
.
_idxs
tensorpack/callbacks/dump.py
View file @
b5f8c73a
...
...
@@ -8,26 +8,27 @@ import numpy as np
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils
import
get_op_
va
r_name
from
..tfutils
import
get_op_
tenso
r_name
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Callback
):
"""
Dump a variable to image(s)
after every epoch to logger.LOG_DIR
.
Dump a variable to image(s)
to ``logger.LOG_DIR`` after every epoch
.
"""
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
"""
:param var_name: the name of the variable.
:param prefix: the filename prefix for saved images. Default is the op name.
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param clip: whether to clip the result to [0, 255]
Args:
var_name (str): the name of the variable.
prefix (str): the filename prefix for saved images. Defaults to the Op name.
map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity.
scale (float): a multiplier on pixel values, applied after map_func.
clip (bool): whether to clip the result to [0, 255].
"""
op_name
,
self
.
var_name
=
get_op_
va
r_name
(
var_name
)
op_name
,
self
.
var_name
=
get_op_
tenso
r_name
(
var_name
)
self
.
func
=
map_func
if
prefix
is
None
:
self
.
prefix
=
op_name
...
...
@@ -45,7 +46,7 @@ class DumpParamAsImage(Callback):
val
=
self
.
trainer
.
sess
.
run
(
self
.
var
)
if
self
.
func
is
not
None
:
val
=
self
.
func
(
val
)
if
isinstance
(
val
,
list
):
if
isinstance
(
val
,
list
)
or
val
.
ndim
==
4
:
for
idx
,
im
in
enumerate
(
val
):
self
.
_dump_image
(
im
,
idx
)
else
:
...
...
tensorpack/callbacks/graph.py
View file @
b5f8c73a
...
...
@@ -11,13 +11,19 @@ __all__ = ['RunOp']
class
RunOp
(
Callback
):
""" Run an
op periodically
"""
""" Run an
Op.
"""
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
"""
:param setup_func: a function that returns the op in the graph
:param run_before: run the op before training
:param run_epoch: run the op on every epoch trigger
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_epoch (bool): run the Op on every epoch trigger
Examples:
The `DQN Example
<https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
uses this callback to update target network.
"""
self
.
setup_func
=
setup_func
self
.
run_before
=
run_before
...
...
@@ -25,7 +31,6 @@ class RunOp(Callback):
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
# self._op_name = self._op.name
def
_before_train
(
self
):
if
self
.
run_before
:
...
...
tensorpack/callbacks/group.py
View file @
b5f8c73a
...
...
@@ -44,12 +44,14 @@ class CallbackTimeLogger(object):
class
Callbacks
(
Callback
):
"""
A container to hold all callbacks, and execute them in the right order and proper session.
A container to hold all callbacks, and execute them in the right order
(e.g. :class:`StatPrinter` will be executed at last).
"""
def
__init__
(
self
,
cbs
):
"""
:param cbs: a list of `Callbacks`
Args:
cbs(list): a list of :class:`Callback` instances.
"""
# check type
for
cb
in
cbs
:
...
...
tensorpack/callbacks/inference.py
View file @
b5f8c73a
...
...
@@ -12,12 +12,13 @@ from ..utils import logger
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..tfutils
import
get_op_var_name
__all__
=
[
'
ClassificationErro
r'
,
'
ScalarStats'
,
'Inference
r'
,
'BinaryClassificationStats'
]
__all__
=
[
'
ScalarStats'
,
'Inference
r'
,
'
ClassificationErro
r'
,
'BinaryClassificationStats'
]
@
six
.
add_metaclass
(
ABCMeta
)
class
Inferencer
(
object
):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """
def
before_inference
(
self
):
"""
...
...
@@ -30,7 +31,11 @@ class Inferencer(object):
def
datapoint
(
self
,
output
):
"""
Called after complete running every data point
Called after each new datapoint finished the forward inference.
Args:
output(list): list of output this inferencer needs. Has the same
length as ``self.get_output_tensors()``.
"""
self
.
_datapoint
(
output
)
...
...
@@ -41,8 +46,8 @@ class Inferencer(object):
def
after_inference
(
self
):
"""
Called after a round of inference ends.
Returns a dict of statistics which will be logged by the
InferenceRunner
.
The inferencer needs to handle other
kind of logging by their own
.
Returns a dict of statistics which will be logged by the
:class:`InferenceRunner`
.
The inferencer needs to handle other
type of logging by itself, if there is any
.
"""
return
self
.
_after_inference
()
...
...
@@ -51,7 +56,7 @@ class Inferencer(object):
def
get_output_tensors
(
self
):
"""
Return a list of tensor names
needed for this inference
Return a list of tensor names
this inferencer needs.
"""
return
self
.
_get_output_tensors
()
...
...
@@ -62,15 +67,16 @@ class Inferencer(object):
class
ScalarStats
(
Inferencer
):
"""
Write some scalar tensor to both stat and summary.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the inference dataflow.
Statistics of some scalar tensor.
The value will be averaged over all given datapoints.
"""
def
__init__
(
self
,
names_to_print
,
prefix
=
'validation'
):
"""
:param names_to_print: list of names of tensors, or just a name
:param prefix: an optional prefix for logging
Args:
names_to_print(list or str): list of names or just one name. The
corresponding tensors have to be scalar.
prefix(str): a prefix for logging
"""
if
not
isinstance
(
names_to_print
,
list
):
self
.
names
=
[
names_to_print
]
...
...
@@ -85,6 +91,8 @@ class ScalarStats(Inferencer):
self
.
stats
=
[]
def
_datapoint
(
self
,
output
):
for
o
in
output
:
assert
isinstance
(
o
,
(
float
,
np
.
float32
)),
type
(
o
)
self
.
stats
.
append
(
output
)
def
_after_inference
(
self
):
...
...
@@ -101,24 +109,27 @@ class ScalarStats(Inferencer):
class
ClassificationError
(
Inferencer
):
"""
Compute classification error in batch mode, from a `
wrong` variable
Compute classification error in batch mode, from a `
`wrong`` tensor.
The `
wrong` tensor is supposed to be an 0/1 integer
vector containing
whether each sample in the batch is
incorrectly
classified.
You can use `
tf.nn.in_top_k` to produce this vector record top-k error as well
.
The `
`wrong`` tensor is supposed to be a binary
vector containing
whether each sample in the batch is
*incorrectly*
classified.
You can use `
`tf.nn.in_top_k`` to produce this vector
.
This
callback produce
the "true" error,
This
Inferencer produces
the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
Therefore the result
is
different from averaging the error rate of each batch.
Therefore the result
can be
different from averaging the error rate of each batch.
"""
def
__init__
(
self
,
wrong_
va
r_name
=
'incorrect_vector'
,
summary_name
=
'val_error'
):
def
__init__
(
self
,
wrong_
tenso
r_name
=
'incorrect_vector'
,
summary_name
=
'val_error'
):
"""
:param wrong_var_name: name of the `wrong` variable
:param summary_name: the name for logging
Args:
wrong_tensor_name(str): name of the ``wrong`` tensor.
The default is the same as the default output name of
:meth:`prediction_incorrect`.
summary_name(str): the name for logging.
"""
self
.
wrong_var_name
=
wrong_
va
r_name
self
.
wrong_var_name
=
wrong_
tenso
r_name
self
.
summary_name
=
summary_name
def
_get_output_tensors
(
self
):
...
...
@@ -144,21 +155,23 @@ class ClassificationError(Inferencer):
class
BinaryClassificationStats
(
Inferencer
):
""" Compute precision/recall in binary classification, given the
"""
Compute precision / recall in binary classification, given the
prediction vector and the label vector.
"""
def
__init__
(
self
,
pred_
var_name
,
label_va
r_name
,
summary_prefix
=
'val'
):
def
__init__
(
self
,
pred_
tensor_name
,
label_tenso
r_name
,
summary_prefix
=
'val'
):
"""
:param pred_var_name: name of the 0/1 prediction tensor.
:param label_var_name: name of the 0/1 label tensor.
Args:
pred_tensor_name(str): name of the 0/1 prediction tensor.
label_tensor_name(str): name of the 0/1 label tensor.
"""
self
.
pred_
var_name
=
pred_va
r_name
self
.
label_
var_name
=
label_va
r_name
self
.
pred_
tensor_name
=
pred_tenso
r_name
self
.
label_
tensor_name
=
label_tenso
r_name
self
.
prefix
=
summary_prefix
def
_get_output_tensors
(
self
):
return
[
self
.
pred_
var_name
,
self
.
label_va
r_name
]
return
[
self
.
pred_
tensor_name
,
self
.
label_tenso
r_name
]
def
_before_inference
(
self
):
self
.
stat
=
BinaryStatistics
()
...
...
tensorpack/callbacks/inference_runner.py
View file @
b5f8c73a
...
...
@@ -11,13 +11,36 @@ from six.moves import zip, range
from
..dataflow
import
DataFlow
from
.base
import
Callback
from
.inference
import
Inferencer
from
.dispatcher
import
OutputTensorDispatcer
from
..utils
import
logger
,
get_tqdm
from
..tfutils.common
import
get_op_tensor_name
from
..train.input_data
import
FeedfreeInput
__all__
=
[
'InferenceRunner'
]
class
OutputTensorDispatcer
(
object
):
def
__init__
(
self
):
self
.
_names
=
[]
self
.
_idxs
=
[]
def
add_entry
(
self
,
names
):
v
=
[]
for
n
in
names
:
tensorname
=
get_op_tensor_name
(
n
)[
1
]
if
tensorname
in
self
.
_names
:
v
.
append
(
self
.
_names
.
index
(
tensorname
))
else
:
self
.
_names
.
append
(
tensorname
)
v
.
append
(
len
(
self
.
_names
)
-
1
)
self
.
_idxs
.
append
(
v
)
def
get_all_names
(
self
):
return
self
.
_names
def
get_idx_for_each_entry
(
self
):
return
self
.
_idxs
def
summary_inferencer
(
trainer
,
infs
):
for
inf
in
infs
:
ret
=
inf
.
after_inference
()
...
...
@@ -32,17 +55,19 @@ def summary_inferencer(trainer, infs):
class
InferenceRunner
(
Callback
):
"""
A callback that runs different kinds of inferencer.
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
"""
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
_
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
def
__init__
(
self
,
ds
,
infs
,
input_tensors
=
None
):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param infs: a list of `Inferencer` instance.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_tensors(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
...
...
@@ -78,7 +103,7 @@ class InferenceRunner(Callback):
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
all_names
=
dispatcer
.
get_all_names
()
IOTensor
=
InferenceRunner
.
IOTensor
IOTensor
=
InferenceRunner
.
_
IOTensor
self
.
output_tensors
=
list
(
filter
(
lambda
x
:
x
not
in
self
.
input_tensors
,
all_names
))
...
...
tensorpack/callbacks/param.py
View file @
b5f8c73a
...
...
@@ -13,40 +13,59 @@ from .base import Callback
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
__all__
=
[
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
,
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'StatMonitorParamSetter'
,
'HyperParamSetterWithFunc'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
]
@
six
.
add_metaclass
(
ABCMeta
)
class
HyperParam
(
object
):
""" Base class for a hyper
param
"""
""" Base class for a hyper
param.
"""
def
setup_graph
(
self
):
""" setup the graph in `
setup_graph
` callback stage, if necessary"""
""" setup the graph in `
`setup_graph`
` callback stage, if necessary"""
pass
@
abstractmethod
def
set_value
(
self
,
v
):
""" define how the value of the param will be set"""
"""
Set the value of the param.
Args:
v: the value to be set
"""
pass
@
abstractmethod
def
get_value
(
self
):
"""
Get the value of the param.
"""
pass
@
property
def
readable_name
(
self
):
""" A name to display"""
""" A name to display
"""
return
self
.
_readable_name
class
GraphVarParam
(
HyperParam
):
"""
a variable in the graph
can be a hyperparam"""
"""
A variable in the graph (e.g. learning_rate)
can be a hyperparam"""
def
__init__
(
self
,
name
,
shape
=
[]):
"""
Args:
name(str): name of the variable.
shape(list): shape of the variable.
"""
self
.
name
=
name
self
.
shape
=
shape
self
.
_readable_name
,
self
.
var_name
=
get_op_var_name
(
name
)
def
setup_graph
(
self
):
""" Will setup the assign operator for that variable. """
all_vars
=
tf
.
global_variables
()
for
v
in
all_vars
:
if
v
.
name
==
self
.
var_name
:
...
...
@@ -60,17 +79,24 @@ class GraphVarParam(HyperParam):
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
set_value
(
self
,
v
):
""" Assign the variable a new value. """
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
def
get_value
(
self
):
""" Evaluate the variable. """
return
self
.
var
.
eval
()
class
ObjAttrParam
(
HyperParam
):
"""
an attribute of an object can be a hyperparam
"""
"""
An attribute of an object can be a hyperparam.
"""
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
""" :param readable_name: default to be attrname."""
"""
Args:
obj: the object
attrname (str): the attribute
readable_name(str): The name to display. Defaults to be ``attrname``.
"""
self
.
obj
=
obj
self
.
attrname
=
attrname
if
readable_name
is
None
:
...
...
@@ -87,12 +113,14 @@ class ObjAttrParam(HyperParam):
class
HyperParamSetter
(
Callback
):
"""
Base class to set hyperparameters after
every epoch.
An abstract base callback to set hyperparameters in
every epoch.
"""
def
__init__
(
self
,
param
):
"""
:param param: a `HyperParam` instance, or a string (assumed to be a scalar `GraphVarParam`)
Args:
param(HyperParam or str): if is a :class:`str`, it is assumed to
be a :class:`GraphVarParam`.
"""
# if a string, assumed to be a scalar graph variable
if
isinstance
(
param
,
six
.
string_types
):
...
...
@@ -106,7 +134,13 @@ class HyperParamSetter(Callback):
def
get_value_to_set
(
self
):
"""
:returns: the value to assign to the variable now.
Returns:
The value to assign to the variable.
Note:
Subclasses will implement the abstract method
:meth:`_get_value_to_set`, which should return a new value to
set, or return None to do nothing.
"""
ret
=
self
.
_get_value_to_set
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
...
...
@@ -115,13 +149,17 @@ class HyperParamSetter(Callback):
self
.
last_value
=
ret
return
ret
def
get_current_value
(
self
):
return
self
.
param
.
get_value
()
@
abstractmethod
def
_get_value_to_set
(
self
):
pass
def
get_current_value
(
self
):
"""
Returns:
The current value of the param.
"""
return
self
.
param
.
get_value
()
def
_trigger_epoch
(
self
):
self
.
_set_param
()
...
...
@@ -136,14 +174,19 @@ class HyperParamSetter(Callback):
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters by loading the value from a file each time it get called.
Set hyperparameter by loading the value from a file each time it gets called.
This is useful for manually tuning some parameters (e.g. learning_rate)
without interrupting the training.
"""
def
__init__
(
self
,
param
,
file_name
=
'hyper.txt'
):
"""
:param file_name: a file containing the value of the variable.
Args:
param: same as in :class:`HyperParamSetter`.
file_name(str): a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
param.readable_name, and v is the value
param.readable_name, and v is the value. If the pair is not found,
the param will not be changed.
"""
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
self
.
file_name
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
file_name
)
...
...
@@ -170,15 +213,25 @@ class HumanHyperParamSetter(HyperParamSetter):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
"""
Set hyperparameters by a predefined schedule.
Set hyperparameters by a predefined
epoch-based
schedule.
"""
def
__init__
(
self
,
param
,
schedule
,
interp
=
None
):
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
(ep, val) means set the param to "val" after the `ep`th epoch.
If epoch == 0, the value is set before training.
:param interp: None: no interpolation. 'linear': linear interpolation
Args:
param: same as in :class:`HyperParamSetter`.
schedule(list): with the format ``[(epoch1, val1), (epoch2, val2),
(epoch3, val3), ...]``.
Each ``(ep, val)`` pair means to set the param
to "val" after the `ep`th epoch.
If ep == 0, the value will be set before training.
interp: None: no interpolation. 'linear': linear interpolation
Example:
.. code-block:: python
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
"""
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
...
...
@@ -209,10 +262,20 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class
HyperParamSetterWithFunc
(
HyperParamSetter
):
""" Set the parameter by a function of epoch num and old value. """
def
__init__
(
self
,
param
,
func
):
"""Set hyperparameter by a func
new_value = f(epoch_num, old_value)
"""
Args:
param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
Example:
Decrease by a factor of 0.9 every two epochs:
.. code-block:: python
HyperParamSetterWithFunc('learning_rate',
lambda e, x: x * 0.9 if e
% 2
== 0 else x)
"""
super
(
HyperParamSetterWithFunc
,
self
)
.
__init__
(
param
)
self
.
f
=
func
...
...
@@ -222,22 +285,32 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class
StatMonitorParamSetter
(
HyperParamSetter
):
"""
Change the param by monitoring the change of a statistic.
Change when it wasn't decreasing/increasing enough.
"""
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
last_k
,
reverse
=
False
):
last_k
,
reverse
=
False
):
"""
Set hyperparameter by a func, when a specific stat wasn't
decreasing/increasing enough in the last $k$ epochs.
Change param by `new_value = value_func(old_value)`,
if :
min(stats) >= stats[0] - threshold, where
stats = [`stat_nam` in latest `last_k` epochs]
Args:
param: same as in :class:`HyperParamSetter`.
stat_name (str): name of the statistics.
value_func (float -> float): a function which returns a new value
taking the old value.
threshold (float): change threshold.
last_k (int): last k epochs.
reverse (bool): monitor increasing instead of decreasing.
This callback will change param by ``new_value = value_func(old_value)``, when:
``min(stats) >= stats[0] - threshold``, where
``stats = [stat_name in last k epochs]``
Example:
If validation error wasn't decreasing for 5 epochs, anneal the learning rate:
For example, if error wasn't decreasing, anneal the learning rate:
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
.. code-block:: python
If reverse==True, use 'increasing' instead of decreasing
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
"""
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
self
.
stat_name
=
stat_name
...
...
tensorpack/callbacks/saver.py
View file @
b5f8c73a
...
...
@@ -16,22 +16,19 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class
ModelSaver
(
Callback
):
"""
Save the model to
logger directory
.
Save the model to
``logger.LOG_DIR`` directory every epoch
.
"""
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
var_collections
=
None
):
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
Args:
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(float): see ``tf.train.Saver`` documentation.
var_collections (str or list): the variable collection (or list of collections) to save.
"""
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
if
var_collections
is
None
:
try
:
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
except
:
var_collections
=
tf
.
GraphKeys
.
VARIABLES
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
self
.
var_collections
=
var_collections
...
...
@@ -87,8 +84,25 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
class
MinSaver
(
Callback
):
"""
Separately save the model with minimum value of some statistics.
"""
def
__init__
(
self
,
monitor_stat
,
reverse
=
False
,
filename
=
None
):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel" under ``logger.LOG_DIR``:
.. code-block:: python
def
__init__
(
self
,
monitor_stat
,
reverse
=
True
,
filename
=
None
):
MinSaver('val-error')
"""
self
.
monitor_stat
=
monitor_stat
self
.
reverse
=
reverse
self
.
filename
=
filename
...
...
@@ -128,6 +142,14 @@ class MinSaver(Callback):
class
MaxSaver
(
MinSaver
):
def
__init__
(
self
,
monitor_stat
):
super
(
MaxSaver
,
self
)
.
__init__
(
monitor_stat
,
True
)
"""
Separately save the model with maximum value of some statistics.
"""
def
__init__
(
self
,
monitor_stat
,
filename
=
None
):
"""
Args:
monitor_stat(str): the name of the statistics.
filename (str): the name for the saved model.
Defaults to ``max-{monitor_stat}.tfmodel``.
"""
super
(
MaxSaver
,
self
)
.
__init__
(
monitor_stat
,
True
,
filename
=
filename
)
tensorpack/callbacks/stats.py
View file @
b5f8c73a
...
...
@@ -20,7 +20,8 @@ class StatHolder(object):
def
__init__
(
self
,
log_dir
):
"""
:param log_dir: directory to save the stats.
Args:
log_dir(str): directory to save the stats.
"""
self
.
set_print_tag
([])
self
.
blacklist_tag
=
set
()
...
...
@@ -38,19 +39,24 @@ class StatHolder(object):
def
add_stat
(
self
,
k
,
v
):
"""
Add a stat.
:param k: name
:param v: value
"""
self
.
stat_now
[
k
]
=
float
(
v
)
def
set_print_tag
(
self
,
print_tag
):
"""
Set name of stats to print.
Args:
print_tag: a collection of string.
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
add_blacklist_tag
(
self
,
blacklist_tag
):
""" Disable printing for some tags """
""" Disable printing for some tags
Args:
blacklist_tag: a collection of string.
"""
self
.
blacklist_tag
|=
set
(
blacklist_tag
)
def
get_stat_now
(
self
,
key
):
...
...
@@ -60,6 +66,10 @@ class StatHolder(object):
return
self
.
stat_now
[
key
]
def
get_stat_history
(
self
,
key
):
"""
Returns:
list: all history of a stat.
"""
ret
=
[]
for
h
in
self
.
stat_history
:
v
=
h
.
get
(
key
,
None
)
...
...
@@ -97,13 +107,14 @@ class StatHolder(object):
class
StatPrinter
(
Callback
):
"""
Control what stats to prin
t.
A callback to control what stats to print. Print everything by defaul
t.
"""
def
__init__
(
self
,
print_tag
=
None
):
"""
:param print_tag: a list of regex to match scalar summary to print.
If None, will print all scalar tags
Args:
print_tag: a list of stat names to print.
If None, will print all scalar tags.
"""
self
.
print_tag
=
print_tag
...
...
@@ -125,15 +136,25 @@ class StatPrinter(Callback):
class
SendStat
(
Callback
):
"""
Execute a command with some specific stats.
For example, send the stats to your phone through pushbullet:
This is useful for, e.g. building a custom statistics monitor.
"""
def
__init__
(
self
,
command
,
stats
):
"""
Args:
command(str): a command to execute. Use format string with stat
names as keys.
stats(list or str): stat name(s) to use.
Example:
Send the stats to your phone through pushbullet:
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes
\
-d type=note -d title="validation error"
\
.. code-block:: python
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes
\\
-d type=note -d title="validation error"
\\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
def
__init__
(
self
,
command
,
stats
):
self
.
command
=
command
if
not
isinstance
(
stats
,
list
):
stats
=
[
stats
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment