Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0fda5f71
Commit
0fda5f71
authored
Mar 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
merge triggerable to callbacks.
parent
c088e2a6
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
43 additions
and
71 deletions
+43
-71
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+18
-38
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+2
-2
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+2
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+2
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-3
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+3
-3
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+3
-12
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+8
-7
No files found.
tensorpack/callbacks/base.py
View file @
0fda5f71
...
...
@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
import
six
from
..tfutils.common
import
get_op_or_tensor_by_name
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
]
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -31,6 +31,7 @@ class Callback(object):
.. automethod:: _after_run
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _trigger
.. automethod:: _after_train
"""
...
...
@@ -111,7 +112,7 @@ class Callback(object):
def
_trigger_step
(
self
):
"""
Called after each :meth:`Trainer.run_step()` completes.
Called after each :meth:`Trainer.run_step()` completes.
Defaults to no-op.
You can override it to implement, e.g. a ProgressBar.
"""
...
...
@@ -122,7 +123,20 @@ class Callback(object):
def
_trigger_epoch
(
self
):
"""
Called after the completion of every epoch.
Called after the completion of every epoch. Defaults to call ``self.trigger()``
"""
self
.
trigger
()
def
trigger
(
self
):
self
.
_trigger
()
def
_trigger
(
self
):
"""
Override this method to define a general trigger behavior, to be used with trigger schedulers.
Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call this
method both inside an epoch and after an epoch.
When used without the scheduler, this method by default will be called by `trigger_epoch()`.
"""
pass
...
...
@@ -147,40 +161,6 @@ class Callback(object):
return
type
(
self
)
.
__name__
@
six
.
add_metaclass
(
ABCMeta
)
class
Triggerable
(
Callback
):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be called either inside an epoch or between epochs.
Other higher-level wrappers will take the responsibility to determine **when**
to call the trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibility and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
"""
def
trigger
(
self
):
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass
def
_trigger_epoch
(
self
):
""" If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self
.
trigger
()
class
ProxyCallback
(
Callback
):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
...
...
tensorpack/callbacks/dump.py
View file @
0fda5f71
...
...
@@ -6,14 +6,14 @@ import os
import
cv2
import
numpy
as
np
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Triggerable
):
class
DumpParamAsImage
(
Callback
):
"""
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
...
...
tensorpack/callbacks/graph.py
View file @
0fda5f71
...
...
@@ -5,12 +5,12 @@
""" Graph related callbacks"""
from
.base
import
Triggerable
from
.base
import
Callback
__all__
=
[
'RunOp'
]
class
RunOp
(
Triggerable
):
class
RunOp
(
Callback
):
""" Run an Op. """
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
...
...
tensorpack/callbacks/inference_runner.py
View file @
0fda5f71
...
...
@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext
from
..train.input_data
import
TensorInput
,
FeedInput
from
..predict
import
PredictorTowerBuilder
from
.base
import
Triggerable
from
.base
import
Callback
from
.inference
import
Inferencer
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
...
...
@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs):
@
six
.
add_metaclass
(
ABCMeta
)
class
InferenceRunnerBase
(
Triggerable
):
class
InferenceRunnerBase
(
Callback
):
""" Base methods for inference runner"""
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
):
"""
...
...
tensorpack/callbacks/param.py
View file @
0fda5f71
...
...
@@ -9,7 +9,7 @@ import operator
import
six
import
os
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
...
...
@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam):
return
getattr
(
self
.
obj
,
self
.
attrname
)
class
HyperParamSetter
(
Triggerable
):
class
HyperParamSetter
(
Callback
):
"""
An abstract base callback to set hyperparameters.
"""
...
...
tensorpack/callbacks/saver.py
View file @
0fda5f71
...
...
@@ -6,13 +6,13 @@ import tensorflow as tf
import
os
import
shutil
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
class
ModelSaver
(
Triggerable
):
class
ModelSaver
(
Callback
):
"""
Save the model every epoch.
"""
...
...
@@ -67,7 +67,7 @@ class ModelSaver(Triggerable):
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
class
MinSaver
(
Triggerable
):
class
MinSaver
(
Callback
):
"""
Separately save the model with minimum value of some statistics.
"""
...
...
tensorpack/callbacks/stats.py
View file @
0fda5f71
...
...
@@ -4,14 +4,14 @@
import
os
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..utils.develop
import
log_deprecated
__all__
=
[
'StatPrinter'
,
'SendStat'
]
class
StatPrinter
(
Triggerable
):
class
StatPrinter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
log_deprecated
(
"StatPrinter"
,
"No need to add StatPrinter to callbacks anymore!"
,
...
...
@@ -19,7 +19,7 @@ class StatPrinter(Triggerable):
# TODO make it into monitor?
class
SendStat
(
Triggerable
):
class
SendStat
(
Callback
):
"""
Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor.
...
...
tensorpack/callbacks/summary.py
View file @
0fda5f71
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
.base
import
Callback
,
Triggerable
from
.base
import
Callback
__all__
=
[
'MovingAverageSummary'
,
'MergeAllSummaries'
]
...
...
@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
return
[
self
.
ema_op
]
class
MergeAllSummaries
(
Triggerable
):
class
MergeAllSummaries
(
Callback
):
"""
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
"""
...
...
@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable):
return
self
.
trainer
.
monitors
.
put_summary
(
summary
)
def
_
summary_run_alone
(
self
):
def
_
trigger
(
self
):
summary
=
self
.
summary_op
.
eval
()
self
.
trainer
.
monitors
.
put_summary
(
summary
)
def
_trigger_epoch
(
self
):
if
self
.
_run_alone
:
self
.
_summary_run_alone
()
def
_trigger
(
self
):
assert
self
.
_run_alone
,
\
"MergeAllSummaries can be used as a trigger only if run_alone=True."
self
.
_summary_run_alone
()
tensorpack/callbacks/trigger.py
View file @
0fda5f71
...
...
@@ -3,7 +3,8 @@
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
.base
import
ProxyCallback
,
Triggerable
from
.base
import
ProxyCallback
,
Callback
from
..utils.develop
import
log_deprecated
__all__
=
[
'PeriodicTrigger'
,
'PeriodicCallback'
]
...
...
@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
class
PeriodicTrigger
(
ProxyCallback
):
"""
Trigger a :class:`Triggerable` callback every k steps or every k epochs
.
Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method
.
"""
def
__init__
(
self
,
triggerable
,
every_k_steps
=
None
,
every_k_epochs
=
None
):
"""
Args:
triggerable (
Triggerable): a Triggerable instance
.
triggerable (
Callback): a Callback instance with a _trigger method to be called
.
every_k_steps (int): trigger when ``local_step
%
k == 0``. Set to
None to disable.
every_k_epochs (int): trigger when ``epoch_num
%
k == 0``. Set to
...
...
@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback):
every_k_steps and every_k_epochs can be both set, but cannot be both NOne.
"""
assert
isinstance
(
triggerable
,
Triggerable
),
type
(
triggerable
)
assert
isinstance
(
triggerable
,
Callback
),
type
(
triggerable
)
super
(
PeriodicTrigger
,
self
)
.
__init__
(
triggerable
)
assert
(
every_k_epochs
is
not
None
)
or
(
every_k_steps
is
not
None
),
\
"every_k_steps and every_k_epochs cannot be both None!"
...
...
@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback):
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per
epoch, use :class:`PeriodicTrigger` instead.
This wrapper is legacy. It will only proxy the :meth:`trigger_step` method as-is.
To be able to schedule a callback more frequent than once per epoch, use :class:`PeriodicTrigger` instead.
"""
def
__init__
(
self
,
cb
,
period
):
...
...
@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback):
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
log_deprecated
(
"PeriodicCallback"
,
"Use the more powerful `PeriodicTrigger`."
)
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment