Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b5f8c73a
Commit
b5f8c73a
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
sphinx doc for callbacks
parent
b5acbf3a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
312 additions
and
163 deletions
+312
-163
examples/DoReFa-Net/README.md
examples/DoReFa-Net/README.md
+1
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+20
-6
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+10
-6
tensorpack/callbacks/dispatcher.py
tensorpack/callbacks/dispatcher.py
+0
-31
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+11
-10
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+10
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+4
-2
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+41
-28
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+33
-8
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+112
-39
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+35
-13
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+35
-14
No files found.
examples/DoReFa-Net/README.md
View file @
b5f8c73a
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
...
@@ -9,7 +9,7 @@ In this repo, bit operations are performed through `tf.float32`.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
[
google drive
](
https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ
)
.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The
binary-weight 4-bit-activation ResNet-18
model has 59.2% top-1 validation error.
The
__binary-weight 4-bit-activation ResNet-18__
model has 59.2% top-1 validation error.
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
Alternative link to this page:
[
http://dorefa.net
](
http://dorefa.net
)
...
...
tensorpack/callbacks/base.py
View file @
b5f8c73a
...
@@ -27,7 +27,8 @@ class Callback(object):
...
@@ -27,7 +27,8 @@ class Callback(object):
Called before finalizing the graph.
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
Use this callback to setup some ops used in the callback.
:param trainer: :class:`train.Trainer` instance
Args:
trainer(Trainer): the trainer which calls the callback
"""
"""
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
graph
=
tf
.
get_default_graph
()
...
@@ -59,7 +60,7 @@ class Callback(object):
...
@@ -59,7 +60,7 @@ class Callback(object):
"""
"""
Triggered after every epoch.
Triggered after every epoch.
In this function,
self.epoch_num
would be the number of epoch finished.
In this function,
``self.epoch_num``
would be the number of epoch finished.
"""
"""
self
.
epoch_num
+=
1
self
.
epoch_num
+=
1
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
...
@@ -72,8 +73,15 @@ class Callback(object):
...
@@ -72,8 +73,15 @@ class Callback(object):
class
ProxyCallback
(
Callback
):
class
ProxyCallback
(
Callback
):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
"""
def
__init__
(
self
,
cb
):
def
__init__
(
self
,
cb
):
"""
Args:
cb(Callback): the underlying callback
"""
self
.
cb
=
cb
self
.
cb
=
cb
def
_before_train
(
self
):
def
_before_train
(
self
):
...
@@ -94,14 +102,20 @@ class ProxyCallback(Callback):
...
@@ -94,14 +102,20 @@ class ProxyCallback(Callback):
class
PeriodicCallback
(
ProxyCallback
):
class
PeriodicCallback
(
ProxyCallback
):
"""
"""
A callback to be triggered after every `period` epochs.
Wrap a callback so that it is triggered after every ``period`` epochs.
Doesn't work for trigger_step
Doesn't work for ``trigger_step``.
"""
"""
def
__init__
(
self
,
cb
,
period
):
def
__init__
(
self
,
cb
,
period
):
"""
"""
:param cb: a `Callback`
Args:
:param period: int
cb(Callback): the callback to be triggered periodically
period(int): the period
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
self
.
period
=
int
(
period
)
...
...
tensorpack/callbacks/concurrency.py
View file @
b5f8c73a
...
@@ -11,15 +11,19 @@ __all__ = ['StartProcOrThread']
...
@@ -11,15 +11,19 @@ __all__ = ['StartProcOrThread']
class
StartProcOrThread
(
Callback
):
class
StartProcOrThread
(
Callback
):
"""
Start some threads or processes before training.
"""
def
__init__
(
self
,
procs_threads
):
def
__init__
(
self
,
startable
):
"""
"""
Start extra threads and processes before training
Args:
:param procs_threads: list of processes or threads
startable(list): list of processes or threads which have ``start()`` method.
Can also be a single instance of process of thread.
"""
"""
if
not
isinstance
(
procs_threads
,
list
):
if
not
isinstance
(
startable
,
list
):
procs_threads
=
[
procs_threads
]
startable
=
[
startable
]
self
.
_procs_threads
=
procs_threads
self
.
_procs_threads
=
startable
def
_before_train
(
self
):
def
_before_train
(
self
):
logger
.
info
(
"Starting "
+
logger
.
info
(
"Starting "
+
...
...
tensorpack/callbacks/dispatcher.py
deleted
100644 → 0
View file @
b5acbf3a
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dispatcher.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..tfutils.common
import
get_op_tensor_name
__all__
=
[
'OutputTensorDispatcer'
]
class
OutputTensorDispatcer
(
object
):
def
__init__
(
self
):
self
.
_names
=
[]
self
.
_idxs
=
[]
def
add_entry
(
self
,
names
):
v
=
[]
for
n
in
names
:
tensorname
=
get_op_tensor_name
(
n
)[
1
]
if
tensorname
in
self
.
_names
:
v
.
append
(
self
.
_names
.
index
(
tensorname
))
else
:
self
.
_names
.
append
(
tensorname
)
v
.
append
(
len
(
self
.
_names
)
-
1
)
self
.
_idxs
.
append
(
v
)
def
get_all_names
(
self
):
return
self
.
_names
def
get_idx_for_each_entry
(
self
):
return
self
.
_idxs
tensorpack/callbacks/dump.py
View file @
b5f8c73a
...
@@ -8,26 +8,27 @@ import numpy as np
...
@@ -8,26 +8,27 @@ import numpy as np
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_
va
r_name
from
..tfutils
import
get_op_
tenso
r_name
__all__
=
[
'DumpParamAsImage'
]
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Callback
):
class
DumpParamAsImage
(
Callback
):
"""
"""
Dump a variable to image(s)
after every epoch to logger.LOG_DIR
.
Dump a variable to image(s)
to ``logger.LOG_DIR`` after every epoch
.
"""
"""
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
"""
"""
:param var_name: the name of the variable.
Args:
:param prefix: the filename prefix for saved images. Default is the op name.
var_name (str): the name of the variable.
:param map_func: map the value of the variable to an image or list of
prefix (str): the filename prefix for saved images. Defaults to the Op name.
images of shape [h, w] or [h, w, c]. If None, will use identity
map_func: map the value of the variable to an image or list of
:param scale: a multiplier on pixel values, applied after map_func. default to 255
images of shape [h, w] or [h, w, c]. If None, will use identity.
:param clip: whether to clip the result to [0, 255]
scale (float): a multiplier on pixel values, applied after map_func.
clip (bool): whether to clip the result to [0, 255].
"""
"""
op_name
,
self
.
var_name
=
get_op_
va
r_name
(
var_name
)
op_name
,
self
.
var_name
=
get_op_
tenso
r_name
(
var_name
)
self
.
func
=
map_func
self
.
func
=
map_func
if
prefix
is
None
:
if
prefix
is
None
:
self
.
prefix
=
op_name
self
.
prefix
=
op_name
...
@@ -45,7 +46,7 @@ class DumpParamAsImage(Callback):
...
@@ -45,7 +46,7 @@ class DumpParamAsImage(Callback):
val
=
self
.
trainer
.
sess
.
run
(
self
.
var
)
val
=
self
.
trainer
.
sess
.
run
(
self
.
var
)
if
self
.
func
is
not
None
:
if
self
.
func
is
not
None
:
val
=
self
.
func
(
val
)
val
=
self
.
func
(
val
)
if
isinstance
(
val
,
list
):
if
isinstance
(
val
,
list
)
or
val
.
ndim
==
4
:
for
idx
,
im
in
enumerate
(
val
):
for
idx
,
im
in
enumerate
(
val
):
self
.
_dump_image
(
im
,
idx
)
self
.
_dump_image
(
im
,
idx
)
else
:
else
:
...
...
tensorpack/callbacks/graph.py
View file @
b5f8c73a
...
@@ -11,13 +11,19 @@ __all__ = ['RunOp']
...
@@ -11,13 +11,19 @@ __all__ = ['RunOp']
class
RunOp
(
Callback
):
class
RunOp
(
Callback
):
""" Run an
op periodically
"""
""" Run an
Op.
"""
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
"""
"""
:param setup_func: a function that returns the op in the graph
Args:
:param run_before: run the op before training
setup_func: a function that returns the Op in the graph
:param run_epoch: run the op on every epoch trigger
run_before (bool): run the Op before training
run_epoch (bool): run the Op on every epoch trigger
Examples:
The `DQN Example
<https://github.com/ppwwyyxx/tensorpack/blob/master/examples/Atari2600/DQN.py#L182>`_
uses this callback to update target network.
"""
"""
self
.
setup_func
=
setup_func
self
.
setup_func
=
setup_func
self
.
run_before
=
run_before
self
.
run_before
=
run_before
...
@@ -25,7 +31,6 @@ class RunOp(Callback):
...
@@ -25,7 +31,6 @@ class RunOp(Callback):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
self
.
_op
=
self
.
setup_func
()
# self._op_name = self._op.name
def
_before_train
(
self
):
def
_before_train
(
self
):
if
self
.
run_before
:
if
self
.
run_before
:
...
...
tensorpack/callbacks/group.py
View file @
b5f8c73a
...
@@ -44,12 +44,14 @@ class CallbackTimeLogger(object):
...
@@ -44,12 +44,14 @@ class CallbackTimeLogger(object):
class
Callbacks
(
Callback
):
class
Callbacks
(
Callback
):
"""
"""
A container to hold all callbacks, and execute them in the right order and proper session.
A container to hold all callbacks, and execute them in the right order
(e.g. :class:`StatPrinter` will be executed at last).
"""
"""
def
__init__
(
self
,
cbs
):
def
__init__
(
self
,
cbs
):
"""
"""
:param cbs: a list of `Callbacks`
Args:
cbs(list): a list of :class:`Callback` instances.
"""
"""
# check type
# check type
for
cb
in
cbs
:
for
cb
in
cbs
:
...
...
tensorpack/callbacks/inference.py
View file @
b5f8c73a
...
@@ -12,12 +12,13 @@ from ..utils import logger
...
@@ -12,12 +12,13 @@ from ..utils import logger
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..tfutils
import
get_op_var_name
from
..tfutils
import
get_op_var_name
__all__
=
[
'
ClassificationErro
r'
,
__all__
=
[
'
ScalarStats'
,
'Inference
r'
,
'
ScalarStats'
,
'Inference
r'
,
'BinaryClassificationStats'
]
'
ClassificationErro
r'
,
'BinaryClassificationStats'
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
Inferencer
(
object
):
class
Inferencer
(
object
):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """
def
before_inference
(
self
):
def
before_inference
(
self
):
"""
"""
...
@@ -30,7 +31,11 @@ class Inferencer(object):
...
@@ -30,7 +31,11 @@ class Inferencer(object):
def
datapoint
(
self
,
output
):
def
datapoint
(
self
,
output
):
"""
"""
Called after complete running every data point
Called after each new datapoint finished the forward inference.
Args:
output(list): list of output this inferencer needs. Has the same
length as ``self.get_output_tensors()``.
"""
"""
self
.
_datapoint
(
output
)
self
.
_datapoint
(
output
)
...
@@ -41,8 +46,8 @@ class Inferencer(object):
...
@@ -41,8 +46,8 @@ class Inferencer(object):
def
after_inference
(
self
):
def
after_inference
(
self
):
"""
"""
Called after a round of inference ends.
Called after a round of inference ends.
Returns a dict of statistics which will be logged by the
InferenceRunner
.
Returns a dict of statistics which will be logged by the
:class:`InferenceRunner`
.
The inferencer needs to handle other
kind of logging by their own
.
The inferencer needs to handle other
type of logging by itself, if there is any
.
"""
"""
return
self
.
_after_inference
()
return
self
.
_after_inference
()
...
@@ -51,7 +56,7 @@ class Inferencer(object):
...
@@ -51,7 +56,7 @@ class Inferencer(object):
def
get_output_tensors
(
self
):
def
get_output_tensors
(
self
):
"""
"""
Return a list of tensor names
needed for this inference
Return a list of tensor names
this inferencer needed.
"""
"""
return
self
.
_get_output_tensors
()
return
self
.
_get_output_tensors
()
...
@@ -62,15 +67,16 @@ class Inferencer(object):
...
@@ -62,15 +67,16 @@ class Inferencer(object):
class
ScalarStats
(
Inferencer
):
class
ScalarStats
(
Inferencer
):
"""
"""
Write some scalar tensor to both stat and summary.
Statistics of some scalar tensor.
The output of the given Ops must be a scalar.
The value will be averaged over all given datapoints.
The value will be averaged over all data points in the inference dataflow.
"""
"""
def
__init__
(
self
,
names_to_print
,
prefix
=
'validation'
):
def
__init__
(
self
,
names_to_print
,
prefix
=
'validation'
):
"""
"""
:param names_to_print: list of names of tensors, or just a name
Args:
:param prefix: an optional prefix for logging
names_to_print(list or str): list of names or just one name. The
corresponding tensors have to be scalar.
prefix(str): a prefix for logging
"""
"""
if
not
isinstance
(
names_to_print
,
list
):
if
not
isinstance
(
names_to_print
,
list
):
self
.
names
=
[
names_to_print
]
self
.
names
=
[
names_to_print
]
...
@@ -85,6 +91,8 @@ class ScalarStats(Inferencer):
...
@@ -85,6 +91,8 @@ class ScalarStats(Inferencer):
self
.
stats
=
[]
self
.
stats
=
[]
def
_datapoint
(
self
,
output
):
def
_datapoint
(
self
,
output
):
for
o
in
output
:
assert
isinstance
(
o
,
(
float
,
np
.
float32
)),
type
(
o
)
self
.
stats
.
append
(
output
)
self
.
stats
.
append
(
output
)
def
_after_inference
(
self
):
def
_after_inference
(
self
):
...
@@ -101,24 +109,27 @@ class ScalarStats(Inferencer):
...
@@ -101,24 +109,27 @@ class ScalarStats(Inferencer):
class
ClassificationError
(
Inferencer
):
class
ClassificationError
(
Inferencer
):
"""
"""
Compute classification error in batch mode, from a `
wrong` variable
Compute classification error in batch mode, from a `
`wrong`` tensor.
The `
wrong` tensor is supposed to be an 0/1 integer
vector containing
The `
`wrong`` tensor is supposed to be an binary
vector containing
whether each sample in the batch is
incorrectly
classified.
whether each sample in the batch is
*incorrectly*
classified.
You can use `
tf.nn.in_top_k` to produce this vector record top-k error as well
.
You can use `
`tf.nn.in_top_k`` to produce this vector
.
This
callback produce
the "true" error,
This
Inferencer produces
the "true" error,
taking account of the fact that batches might not have the same size in
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
testing (because the size of test set might not be a multiple of batch size).
Therefore the result
is
different from averaging the error rate of each batch.
Therefore the result
can be
different from averaging the error rate of each batch.
"""
"""
def
__init__
(
self
,
wrong_
va
r_name
=
'incorrect_vector'
,
summary_name
=
'val_error'
):
def
__init__
(
self
,
wrong_
tenso
r_name
=
'incorrect_vector'
,
summary_name
=
'val_error'
):
"""
"""
:param wrong_var_name: name of the `wrong` variable
Args:
:param summary_name: the name for logging
wrong_tensor_name(str): name of the ``wrong`` tensor.
The default is the same as the default output name of
:meth:`prediction_incorrect`.
summary_name(str): the name for logging.
"""
"""
self
.
wrong_var_name
=
wrong_
va
r_name
self
.
wrong_var_name
=
wrong_
tenso
r_name
self
.
summary_name
=
summary_name
self
.
summary_name
=
summary_name
def
_get_output_tensors
(
self
):
def
_get_output_tensors
(
self
):
...
@@ -144,21 +155,23 @@ class ClassificationError(Inferencer):
...
@@ -144,21 +155,23 @@ class ClassificationError(Inferencer):
class
BinaryClassificationStats
(
Inferencer
):
class
BinaryClassificationStats
(
Inferencer
):
""" Compute precision/recall in binary classification, given the
"""
Compute precision / recall in binary classification, given the
prediction vector and the label vector.
prediction vector and the label vector.
"""
"""
def
__init__
(
self
,
pred_
var_name
,
label_va
r_name
,
summary_prefix
=
'val'
):
def
__init__
(
self
,
pred_
tensor_name
,
label_tenso
r_name
,
summary_prefix
=
'val'
):
"""
"""
:param pred_var_name: name of the 0/1 prediction tensor.
Args:
:param label_var_name: name of the 0/1 label tensor.
pred_tensor_name(str): name of the 0/1 prediction tensor.
label_tensor_name(str): name of the 0/1 label tensor.
"""
"""
self
.
pred_
var_name
=
pred_va
r_name
self
.
pred_
tensor_name
=
pred_tenso
r_name
self
.
label_
var_name
=
label_va
r_name
self
.
label_
tensor_name
=
label_tenso
r_name
self
.
prefix
=
summary_prefix
self
.
prefix
=
summary_prefix
def
_get_output_tensors
(
self
):
def
_get_output_tensors
(
self
):
return
[
self
.
pred_
var_name
,
self
.
label_va
r_name
]
return
[
self
.
pred_
tensor_name
,
self
.
label_tenso
r_name
]
def
_before_inference
(
self
):
def
_before_inference
(
self
):
self
.
stat
=
BinaryStatistics
()
self
.
stat
=
BinaryStatistics
()
...
...
tensorpack/callbacks/inference_runner.py
View file @
b5f8c73a
...
@@ -11,13 +11,36 @@ from six.moves import zip, range
...
@@ -11,13 +11,36 @@ from six.moves import zip, range
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
.base
import
Callback
from
.base
import
Callback
from
.inference
import
Inferencer
from
.inference
import
Inferencer
from
.dispatcher
import
OutputTensorDispatcer
from
..utils
import
logger
,
get_tqdm
from
..utils
import
logger
,
get_tqdm
from
..tfutils.common
import
get_op_tensor_name
from
..train.input_data
import
FeedfreeInput
from
..train.input_data
import
FeedfreeInput
__all__
=
[
'InferenceRunner'
]
__all__
=
[
'InferenceRunner'
]
class
OutputTensorDispatcer
(
object
):
def
__init__
(
self
):
self
.
_names
=
[]
self
.
_idxs
=
[]
def
add_entry
(
self
,
names
):
v
=
[]
for
n
in
names
:
tensorname
=
get_op_tensor_name
(
n
)[
1
]
if
tensorname
in
self
.
_names
:
v
.
append
(
self
.
_names
.
index
(
tensorname
))
else
:
self
.
_names
.
append
(
tensorname
)
v
.
append
(
len
(
self
.
_names
)
-
1
)
self
.
_idxs
.
append
(
v
)
def
get_all_names
(
self
):
return
self
.
_names
def
get_idx_for_each_entry
(
self
):
return
self
.
_idxs
def
summary_inferencer
(
trainer
,
infs
):
def
summary_inferencer
(
trainer
,
infs
):
for
inf
in
infs
:
for
inf
in
infs
:
ret
=
inf
.
after_inference
()
ret
=
inf
.
after_inference
()
...
@@ -32,17 +55,19 @@ def summary_inferencer(trainer, infs):
...
@@ -32,17 +55,19 @@ def summary_inferencer(trainer, infs):
class
InferenceRunner
(
Callback
):
class
InferenceRunner
(
Callback
):
"""
"""
A callback that runs different kinds of inferencer.
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
"""
"""
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
_
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
def
__init__
(
self
,
ds
,
infs
,
input_tensors
=
None
):
def
__init__
(
self
,
ds
,
infs
,
input_tensors
=
None
):
"""
"""
:param ds: inference dataset. a `DataFlow` instance.
Args:
:param infs: a list of `Inferencer` instance.
ds (DataFlow): the DataFlow to run inferencer on.
:param input_tensor_names: list of tensors to feed the dataflow to.
infs (list): a list of `Inferencer` instances.
default to all the input placeholders.
input_tensor_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
...
@@ -78,7 +103,7 @@ class InferenceRunner(Callback):
...
@@ -78,7 +103,7 @@ class InferenceRunner(Callback):
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
all_names
=
dispatcer
.
get_all_names
()
all_names
=
dispatcer
.
get_all_names
()
IOTensor
=
InferenceRunner
.
IOTensor
IOTensor
=
InferenceRunner
.
_
IOTensor
self
.
output_tensors
=
list
(
filter
(
self
.
output_tensors
=
list
(
filter
(
lambda
x
:
x
not
in
self
.
input_tensors
,
all_names
))
lambda
x
:
x
not
in
self
.
input_tensors
,
all_names
))
...
...
tensorpack/callbacks/param.py
View file @
b5f8c73a
...
@@ -13,40 +13,59 @@ from .base import Callback
...
@@ -13,40 +13,59 @@ from .base import Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
from
..tfutils
import
get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
__all__
=
[
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
,
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'StatMonitorParamSetter'
,
'HyperParamSetterWithFunc'
,
'StatMonitorParamSetter'
,
'HyperParamSetterWithFunc'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
HyperParam
(
object
):
class
HyperParam
(
object
):
""" Base class for a hyper
param
"""
""" Base class for a hyper
param.
"""
def
setup_graph
(
self
):
def
setup_graph
(
self
):
""" setup the graph in `
setup_graph
` callback stage, if necessary"""
""" setup the graph in `
`setup_graph`
` callback stage, if necessary"""
pass
pass
@
abstractmethod
@
abstractmethod
def
set_value
(
self
,
v
):
def
set_value
(
self
,
v
):
""" define how the value of the param will be set"""
"""
Set the value of the param.
Args:
v: the value to be set
"""
pass
@
abstractmethod
def
get_value
(
self
):
"""
Get the value of the param.
"""
pass
pass
@
property
@
property
def
readable_name
(
self
):
def
readable_name
(
self
):
""" A name to display"""
""" A name to display
"""
return
self
.
_readable_name
return
self
.
_readable_name
class
GraphVarParam
(
HyperParam
):
class
GraphVarParam
(
HyperParam
):
"""
a variable in the graph
can be a hyperparam"""
"""
A variable in the graph (e.g. learning_rate)
can be a hyperparam"""
def
__init__
(
self
,
name
,
shape
=
[]):
def
__init__
(
self
,
name
,
shape
=
[]):
"""
Args:
name(str): name of the variable.
shape(list): shape of the variable.
"""
self
.
name
=
name
self
.
name
=
name
self
.
shape
=
shape
self
.
shape
=
shape
self
.
_readable_name
,
self
.
var_name
=
get_op_var_name
(
name
)
self
.
_readable_name
,
self
.
var_name
=
get_op_var_name
(
name
)
def
setup_graph
(
self
):
def
setup_graph
(
self
):
""" Will setup the assign operator for that variable. """
all_vars
=
tf
.
global_variables
()
all_vars
=
tf
.
global_variables
()
for
v
in
all_vars
:
for
v
in
all_vars
:
if
v
.
name
==
self
.
var_name
:
if
v
.
name
==
self
.
var_name
:
...
@@ -60,17 +79,24 @@ class GraphVarParam(HyperParam):
...
@@ -60,17 +79,24 @@ class GraphVarParam(HyperParam):
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
set_value
(
self
,
v
):
def
set_value
(
self
,
v
):
""" Assign the variable a new value. """
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
def
get_value
(
self
):
def
get_value
(
self
):
""" Evaluate the variable. """
return
self
.
var
.
eval
()
return
self
.
var
.
eval
()
class
ObjAttrParam
(
HyperParam
):
class
ObjAttrParam
(
HyperParam
):
"""
an attribute of an object can be a hyperparam
"""
"""
An attribute of an object can be a hyperparam.
"""
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
""" :param readable_name: default to be attrname."""
"""
Args:
obj: the object
attrname (str): the attribute
readable_name(str): The name to display. Defaults to be ``attrname``.
"""
self
.
obj
=
obj
self
.
obj
=
obj
self
.
attrname
=
attrname
self
.
attrname
=
attrname
if
readable_name
is
None
:
if
readable_name
is
None
:
...
@@ -87,12 +113,14 @@ class ObjAttrParam(HyperParam):
...
@@ -87,12 +113,14 @@ class ObjAttrParam(HyperParam):
class
HyperParamSetter
(
Callback
):
class
HyperParamSetter
(
Callback
):
"""
"""
Base class to set hyperparameters after
every epoch.
An abstract base callback to set hyperparameters in
every epoch.
"""
"""
def
__init__
(
self
,
param
):
def
__init__
(
self
,
param
):
"""
"""
:param param: a `HyperParam` instance, or a string (assumed to be a scalar `GraphVarParam`)
Args:
param(HyperParam or str): if is a :class:`str`, it is assumed to
be a :class:`GraphVarParam`.
"""
"""
# if a string, assumed to be a scalar graph variable
# if a string, assumed to be a scalar graph variable
if
isinstance
(
param
,
six
.
string_types
):
if
isinstance
(
param
,
six
.
string_types
):
...
@@ -106,7 +134,13 @@ class HyperParamSetter(Callback):
...
@@ -106,7 +134,13 @@ class HyperParamSetter(Callback):
def
get_value_to_set
(
self
):
def
get_value_to_set
(
self
):
"""
"""
:returns: the value to assign to the variable now.
Returns:
The value to assign to the variable.
Note:
Subclasses will implemenet the abstract method
:meth:`_get_value_to_set`, which should return a new value to
set, or return None to do nothing.
"""
"""
ret
=
self
.
_get_value_to_set
()
ret
=
self
.
_get_value_to_set
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
...
@@ -115,13 +149,17 @@ class HyperParamSetter(Callback):
...
@@ -115,13 +149,17 @@ class HyperParamSetter(Callback):
self
.
last_value
=
ret
self
.
last_value
=
ret
return
ret
return
ret
def
get_current_value
(
self
):
return
self
.
param
.
get_value
()
@
abstractmethod
@
abstractmethod
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
pass
pass
def
get_current_value
(
self
):
"""
Returns:
The current value of the param.
"""
return
self
.
param
.
get_value
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
_set_param
()
self
.
_set_param
()
...
@@ -136,14 +174,19 @@ class HyperParamSetter(Callback):
...
@@ -136,14 +174,19 @@ class HyperParamSetter(Callback):
class
HumanHyperParamSetter
(
HyperParamSetter
):
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
"""
Set hyperparameters by loading the value from a file each time it get called.
Set hyperparameter by loading the value from a file each time it get called.
This is useful for manually tuning some parameters (e.g. learning_rate)
without interrupting the training.
"""
"""
def
__init__
(
self
,
param
,
file_name
=
'hyper.txt'
):
def
__init__
(
self
,
param
,
file_name
=
'hyper.txt'
):
"""
"""
:param file_name: a file containing the value of the variable.
Args:
Each line in the file is a k:v pair, where k is
param: same as in :class:`HyperParamSetter`.
param.readable_name, and v is the value
file_name(str): a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
param.readable_name, and v is the value. If the pair is not found,
the param will not be changed.
"""
"""
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
self
.
file_name
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
file_name
)
self
.
file_name
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
file_name
)
...
@@ -170,15 +213,25 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -170,15 +213,25 @@ class HumanHyperParamSetter(HyperParamSetter):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
"""
"""
Set hyperparameters by a predefined schedule.
Set hyperparameters by a predefined
epoch-based
schedule.
"""
"""
def
__init__
(
self
,
param
,
schedule
,
interp
=
None
):
def
__init__
(
self
,
param
,
schedule
,
interp
=
None
):
"""
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
Args:
(ep, val) means set the param to "val" after the `ep`th epoch.
param: same as in :class:`HyperParamSetter`.
If epoch == 0, the value is set before training.
schedule(list): with the format ``[(epoch1, val1), (epoch2, val2),
:param interp: None: no interpolation. 'linear': linear interpolation
(epoch3, val3), ...]``.
Each ``(ep, val)`` pair means to set the param
to "val" after the `ep`th epoch.
If ep == 0, the value will be set before training.
interp: None: no interpolation. 'linear': linear interpolation
Example:
.. code-block:: python
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
"""
"""
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
schedule
=
[(
int
(
a
),
float
(
b
))
for
a
,
b
in
schedule
]
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
...
@@ -209,10 +262,20 @@ class ScheduledHyperParamSetter(HyperParamSetter):
...
@@ -209,10 +262,20 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class
HyperParamSetterWithFunc
(
HyperParamSetter
):
class
HyperParamSetterWithFunc
(
HyperParamSetter
):
""" Set the parameter by a function of epoch num and old value. """
def
__init__
(
self
,
param
,
func
):
def
__init__
(
self
,
param
,
func
):
"""Set hyperparameter by a func
"""
new_value = f(epoch_num, old_value)
Args:
param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
Example:
Decrease by a factor of 0.9 every two epochs:
.. code-block:: python
HyperParamSetterWithFunc('learning_rate',
lambda e, x: x * 0.9 if e
% 2
== 0 else x)
"""
"""
super
(
HyperParamSetterWithFunc
,
self
)
.
__init__
(
param
)
super
(
HyperParamSetterWithFunc
,
self
)
.
__init__
(
param
)
self
.
f
=
func
self
.
f
=
func
...
@@ -222,22 +285,32 @@ class HyperParamSetterWithFunc(HyperParamSetter):
...
@@ -222,22 +285,32 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class
StatMonitorParamSetter
(
HyperParamSetter
):
class
StatMonitorParamSetter
(
HyperParamSetter
):
"""
Change the param by monitoring the change of a statistic.
Change when it wasn't decreasing/increasing enough.
"""
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
last_k
,
reverse
=
False
last_k
,
reverse
=
False
):
):
"""
"""
Set hyperparameter by a func, when a specific stat wasn't
Args:
decreasing/increasing enough in the last $k$ epochs.
param: same as in :class:`HyperParamSetter`.
Change param by `new_value = value_func(old_value)`,
stat_name (str): name of the statistics.
if :
value_func (float -> float): a function which returns a new value
min(stats) >= stats[0] - threshold, where
taking the old value.
stats = [`stat_nam` in latest `last_k` epochs]
threshold (float): change threshold.
last_k (int): last k epochs.
reverse (bool): monitor increasing instead of decreasing.
This callback will change param by ``new_value = value_func(old_value)``, when:
``min(stats) >= stats[0] - threshold``, where
``stats = [stat_name in last k epochs]``
Example:
If validation error wasn't decreasing for 5 epochs, anneal the learning rate:
For example, if error wasn't decreasing, anneal the learning rate:
.. code-block:: python
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
If reverse==True, use 'increasing' instead of decreasing
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
"""
"""
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
self
.
stat_name
=
stat_name
self
.
stat_name
=
stat_name
...
...
tensorpack/callbacks/saver.py
View file @
b5f8c73a
...
@@ -16,22 +16,19 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...
@@ -16,22 +16,19 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class
ModelSaver
(
Callback
):
class
ModelSaver
(
Callback
):
"""
"""
Save the model to
logger directory
.
Save the model to
``logger.LOG_DIR`` directory every epoch
.
"""
"""
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
var_collections
=
None
):
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
"""
"""
:param keep_recent: see `tf.train.Saver` documentation.
Args:
:param keep_freq: see `tf.train.Saver` documentation.
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(int): see ``tf.train.Saver`` documentation.
var_collections (str or list): the variable collection (or list of collections) o save.
"""
"""
self
.
keep_recent
=
keep_recent
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
self
.
keep_freq
=
keep_freq
if
var_collections
is
None
:
try
:
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
except
:
var_collections
=
tf
.
GraphKeys
.
VARIABLES
if
not
isinstance
(
var_collections
,
list
):
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
var_collections
=
[
var_collections
]
self
.
var_collections
=
var_collections
self
.
var_collections
=
var_collections
...
@@ -87,8 +84,25 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
...
@@ -87,8 +84,25 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
class
MinSaver
(
Callback
):
class
MinSaver
(
Callback
):
"""
Separately save the model with minimum value of some statistics.
"""
def
__init__
(
self
,
monitor_stat
,
reverse
=
False
,
filename
=
None
):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
def
__init__
(
self
,
monitor_stat
,
reverse
=
True
,
filename
=
None
):
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel" under ``logger.LOG_DIR``:
.. code-block:: python
MinSaver('val-error')
"""
self
.
monitor_stat
=
monitor_stat
self
.
monitor_stat
=
monitor_stat
self
.
reverse
=
reverse
self
.
reverse
=
reverse
self
.
filename
=
filename
self
.
filename
=
filename
...
@@ -128,6 +142,14 @@ class MinSaver(Callback):
...
@@ -128,6 +142,14 @@ class MinSaver(Callback):
class
MaxSaver
(
MinSaver
):
class
MaxSaver
(
MinSaver
):
"""
def
__init__
(
self
,
monitor_stat
):
Separately save the model with maximum value of some statistics.
super
(
MaxSaver
,
self
)
.
__init__
(
monitor_stat
,
True
)
"""
def
__init__
(
self
,
monitor_stat
,
filename
=
None
):
"""
Args:
monitor_stat(str): the name of the statistics.
filename (str): the name for the saved model.
Defaults to ``max-{monitor_stat}.tfmodel``.
"""
super
(
MaxSaver
,
self
)
.
__init__
(
monitor_stat
,
True
,
filename
=
filename
)
tensorpack/callbacks/stats.py
View file @
b5f8c73a
...
@@ -20,7 +20,8 @@ class StatHolder(object):
...
@@ -20,7 +20,8 @@ class StatHolder(object):
def
__init__
(
self
,
log_dir
):
def
__init__
(
self
,
log_dir
):
"""
"""
:param log_dir: directory to save the stats.
Args:
log_dir(str): directory to save the stats.
"""
"""
self
.
set_print_tag
([])
self
.
set_print_tag
([])
self
.
blacklist_tag
=
set
()
self
.
blacklist_tag
=
set
()
...
@@ -38,19 +39,24 @@ class StatHolder(object):
...
@@ -38,19 +39,24 @@ class StatHolder(object):
def
add_stat
(
self
,
k
,
v
):
def
add_stat
(
self
,
k
,
v
):
"""
"""
Add a stat.
Add a stat.
:param k: name
:param v: value
"""
"""
self
.
stat_now
[
k
]
=
float
(
v
)
self
.
stat_now
[
k
]
=
float
(
v
)
def
set_print_tag
(
self
,
print_tag
):
def
set_print_tag
(
self
,
print_tag
):
"""
"""
Set name of stats to print.
Set name of stats to print.
Args:
print_tag: a collection of string.
"""
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
add_blacklist_tag
(
self
,
blacklist_tag
):
def
add_blacklist_tag
(
self
,
blacklist_tag
):
""" Disable printing for some tags """
""" Disable printing for some tags
Args:
blacklist_tag: a collection of string.
"""
self
.
blacklist_tag
|=
set
(
blacklist_tag
)
self
.
blacklist_tag
|=
set
(
blacklist_tag
)
def
get_stat_now
(
self
,
key
):
def
get_stat_now
(
self
,
key
):
...
@@ -60,6 +66,10 @@ class StatHolder(object):
...
@@ -60,6 +66,10 @@ class StatHolder(object):
return
self
.
stat_now
[
key
]
return
self
.
stat_now
[
key
]
def
get_stat_history
(
self
,
key
):
def
get_stat_history
(
self
,
key
):
"""
Returns:
list: all history of a stat.
"""
ret
=
[]
ret
=
[]
for
h
in
self
.
stat_history
:
for
h
in
self
.
stat_history
:
v
=
h
.
get
(
key
,
None
)
v
=
h
.
get
(
key
,
None
)
...
@@ -97,13 +107,14 @@ class StatHolder(object):
...
@@ -97,13 +107,14 @@ class StatHolder(object):
class
StatPrinter
(
Callback
):
class
StatPrinter
(
Callback
):
"""
"""
Control what stats to prin
t.
A callback to control what stats to print. Print everything by defaul
t.
"""
"""
def
__init__
(
self
,
print_tag
=
None
):
def
__init__
(
self
,
print_tag
=
None
):
"""
"""
:param print_tag: a list of regex to match scalar summary to print.
Args:
If None, will print all scalar tags
print_tag: a list of stat names to print.
If None, will print all scalar tags.
"""
"""
self
.
print_tag
=
print_tag
self
.
print_tag
=
print_tag
...
@@ -125,15 +136,25 @@ class StatPrinter(Callback):
...
@@ -125,15 +136,25 @@ class StatPrinter(Callback):
class
SendStat
(
Callback
):
class
SendStat
(
Callback
):
"""
"""
Execute a command with some specific stats.
Execute a command with some specific stats.
For example, send the stats to your phone through pushbullet:
This is useful for, e.g. building a custom statistics monitor.
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes
\
-d type=note -d title="validation error"
\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
"""
def
__init__
(
self
,
command
,
stats
):
def
__init__
(
self
,
command
,
stats
):
"""
Args:
command(str): a command to execute. Use format string with stat
names as keys.
stats(list or str): stat name(s) to use.
Example:
Send the stats to your phone through pushbullet:
.. code-block:: python
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes
\\
-d type=note -d title="validation error"
\\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
self
.
command
=
command
self
.
command
=
command
if
not
isinstance
(
stats
,
list
):
if
not
isinstance
(
stats
,
list
):
stats
=
[
stats
]
stats
=
[
stats
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment