Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0fda5f71
Commit
0fda5f71
authored
Mar 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
merge triggerable to callbacks.
parent
c088e2a6
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
43 additions
and
71 deletions
+43
-71
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+18
-38
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+2
-2
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+2
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+2
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-3
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+3
-3
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+3
-12
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+8
-7
No files found.
tensorpack/callbacks/base.py
View file @
0fda5f71
...
@@ -3,11 +3,11 @@
...
@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
import
six
import
six
from
..tfutils.common
import
get_op_or_tensor_by_name
from
..tfutils.common
import
get_op_or_tensor_by_name
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
,
'Triggerable'
]
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
...
@@ -31,6 +31,7 @@ class Callback(object):
...
@@ -31,6 +31,7 @@ class Callback(object):
.. automethod:: _after_run
.. automethod:: _after_run
.. automethod:: _trigger_step
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _trigger_epoch
.. automethod:: _trigger
.. automethod:: _after_train
.. automethod:: _after_train
"""
"""
...
@@ -111,7 +112,7 @@ class Callback(object):
...
@@ -111,7 +112,7 @@ class Callback(object):
def
_trigger_step
(
self
):
def
_trigger_step
(
self
):
"""
"""
Called after each :meth:`Trainer.run_step()` completes.
Called after each :meth:`Trainer.run_step()` completes.
Defaults to no-op.
You can override it to implement, e.g. a ProgressBar.
You can override it to implement, e.g. a ProgressBar.
"""
"""
...
@@ -122,7 +123,20 @@ class Callback(object):
...
@@ -122,7 +123,20 @@ class Callback(object):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
"""
"""
Called after the completion of every epoch.
Called after the completion of every epoch. Defaults to call ``self.trigger()``
"""
self
.
trigger
()
def
trigger
(
self
):
self
.
_trigger
()
def
_trigger
(
self
):
"""
Override this method to define a general trigger behavior, to be used with trigger schedulers.
Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call this
method both inside an epoch and after an epoch.
When used without the scheduler, this method by default will be called by `trigger_epoch()`.
"""
"""
pass
pass
...
@@ -147,40 +161,6 @@ class Callback(object):
...
@@ -147,40 +161,6 @@ class Callback(object):
return
type
(
self
)
.
__name__
return
type
(
self
)
.
__name__
@
six
.
add_metaclass
(
ABCMeta
)
class
Triggerable
(
Callback
):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be called either inside an epoch or between epochs.
Other higher-level wrappers will take the responsibility to determine **when**
to call the trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibility and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
"""
def
trigger
(
self
):
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass
def
_trigger_epoch
(
self
):
""" If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self
.
trigger
()
class
ProxyCallback
(
Callback
):
class
ProxyCallback
(
Callback
):
""" A callback which proxy all methods to another callback.
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
It's useful as a base class of callbacks which decorate other callbacks.
...
...
tensorpack/callbacks/dump.py
View file @
0fda5f71
...
@@ -6,14 +6,14 @@ import os
...
@@ -6,14 +6,14 @@ import os
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
__all__
=
[
'DumpParamAsImage'
]
__all__
=
[
'DumpParamAsImage'
]
class
DumpParamAsImage
(
Triggerable
):
class
DumpParamAsImage
(
Callback
):
"""
"""
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
...
...
tensorpack/callbacks/graph.py
View file @
0fda5f71
...
@@ -5,12 +5,12 @@
...
@@ -5,12 +5,12 @@
""" Graph related callbacks"""
""" Graph related callbacks"""
from
.base
import
Triggerable
from
.base
import
Callback
__all__
=
[
'RunOp'
]
__all__
=
[
'RunOp'
]
class
RunOp
(
Triggerable
):
class
RunOp
(
Callback
):
""" Run an Op. """
""" Run an Op. """
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
...
...
tensorpack/callbacks/inference_runner.py
View file @
0fda5f71
...
@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext
...
@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext
from
..train.input_data
import
TensorInput
,
FeedInput
from
..train.input_data
import
TensorInput
,
FeedInput
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Triggerable
from
.base
import
Callback
from
.inference
import
Inferencer
from
.inference
import
Inferencer
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
...
@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs):
...
@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs):
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
InferenceRunnerBase
(
Triggerable
):
class
InferenceRunnerBase
(
Callback
):
""" Base methods for inference runner"""
""" Base methods for inference runner"""
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
):
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
):
"""
"""
...
...
tensorpack/callbacks/param.py
View file @
0fda5f71
...
@@ -9,7 +9,7 @@ import operator
...
@@ -9,7 +9,7 @@ import operator
import
six
import
six
import
os
import
os
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
...
@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam):
...
@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam):
return
getattr
(
self
.
obj
,
self
.
attrname
)
return
getattr
(
self
.
obj
,
self
.
attrname
)
class
HyperParamSetter
(
Triggerable
):
class
HyperParamSetter
(
Callback
):
"""
"""
An abstract base callback to set hyperparameters.
An abstract base callback to set hyperparameters.
"""
"""
...
...
tensorpack/callbacks/saver.py
View file @
0fda5f71
...
@@ -6,13 +6,13 @@ import tensorflow as tf
...
@@ -6,13 +6,13 @@ import tensorflow as tf
import
os
import
os
import
shutil
import
shutil
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
class
ModelSaver
(
Triggerable
):
class
ModelSaver
(
Callback
):
"""
"""
Save the model every epoch.
Save the model every epoch.
"""
"""
...
@@ -67,7 +67,7 @@ class ModelSaver(Triggerable):
...
@@ -67,7 +67,7 @@ class ModelSaver(Triggerable):
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
class
MinSaver
(
Triggerable
):
class
MinSaver
(
Callback
):
"""
"""
Separately save the model with minimum value of some statistics.
Separately save the model with minimum value of some statistics.
"""
"""
...
...
tensorpack/callbacks/stats.py
View file @
0fda5f71
...
@@ -4,14 +4,14 @@
...
@@ -4,14 +4,14 @@
import
os
import
os
from
.base
import
Triggerable
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
__all__
=
[
'StatPrinter'
,
'SendStat'
]
__all__
=
[
'StatPrinter'
,
'SendStat'
]
class
StatPrinter
(
Triggerable
):
class
StatPrinter
(
Callback
):
def
__init__
(
self
,
print_tag
=
None
):
def
__init__
(
self
,
print_tag
=
None
):
log_deprecated
(
"StatPrinter"
,
log_deprecated
(
"StatPrinter"
,
"No need to add StatPrinter to callbacks anymore!"
,
"No need to add StatPrinter to callbacks anymore!"
,
...
@@ -19,7 +19,7 @@ class StatPrinter(Triggerable):
...
@@ -19,7 +19,7 @@ class StatPrinter(Triggerable):
# TODO make it into monitor?
# TODO make it into monitor?
class
SendStat
(
Triggerable
):
class
SendStat
(
Callback
):
"""
"""
Execute a command with some specific stats.
Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor.
This is useful for, e.g. building a custom statistics monitor.
...
...
tensorpack/callbacks/summary.py
View file @
0fda5f71
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
.base
import
Callback
,
Triggerable
from
.base
import
Callback
__all__
=
[
'MovingAverageSummary'
,
'MergeAllSummaries'
]
__all__
=
[
'MovingAverageSummary'
,
'MergeAllSummaries'
]
...
@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
...
@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
return
[
self
.
ema_op
]
return
[
self
.
ema_op
]
class
MergeAllSummaries
(
Triggerable
):
class
MergeAllSummaries
(
Callback
):
"""
"""
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
"""
"""
...
@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable):
...
@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable):
return
return
self
.
trainer
.
monitors
.
put_summary
(
summary
)
self
.
trainer
.
monitors
.
put_summary
(
summary
)
def
_
summary_run_alone
(
self
):
def
_
trigger
(
self
):
summary
=
self
.
summary_op
.
eval
()
summary
=
self
.
summary_op
.
eval
()
self
.
trainer
.
monitors
.
put_summary
(
summary
)
self
.
trainer
.
monitors
.
put_summary
(
summary
)
def
_trigger_epoch
(
self
):
if
self
.
_run_alone
:
self
.
_summary_run_alone
()
def
_trigger
(
self
):
assert
self
.
_run_alone
,
\
"MergeAllSummaries can be used as a trigger only if run_alone=True."
self
.
_summary_run_alone
()
tensorpack/callbacks/trigger.py
View file @
0fda5f71
...
@@ -3,7 +3,8 @@
...
@@ -3,7 +3,8 @@
# File: trigger.py
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
.base
import
ProxyCallback
,
Triggerable
from
.base
import
ProxyCallback
,
Callback
from
..utils.develop
import
log_deprecated
__all__
=
[
'PeriodicTrigger'
,
'PeriodicCallback'
]
__all__
=
[
'PeriodicTrigger'
,
'PeriodicCallback'
]
...
@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
...
@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
class
PeriodicTrigger
(
ProxyCallback
):
class
PeriodicTrigger
(
ProxyCallback
):
"""
"""
Trigger a :class:`Triggerable` callback every k steps or every k epochs
.
Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method
.
"""
"""
def
__init__
(
self
,
triggerable
,
every_k_steps
=
None
,
every_k_epochs
=
None
):
def
__init__
(
self
,
triggerable
,
every_k_steps
=
None
,
every_k_epochs
=
None
):
"""
"""
Args:
Args:
triggerable (
Triggerable): a Triggerable instance
.
triggerable (
Callback): a Callback instance with a _trigger method to be called
.
every_k_steps (int): trigger when ``local_step
%
k == 0``. Set to
every_k_steps (int): trigger when ``local_step
%
k == 0``. Set to
None to disable.
None to disable.
every_k_epochs (int): trigger when ``epoch_num
%
k == 0``. Set to
every_k_epochs (int): trigger when ``epoch_num
%
k == 0``. Set to
...
@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback):
...
@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback):
every_k_steps and every_k_epochs can be both set, but cannot be both NOne.
every_k_steps and every_k_epochs can be both set, but cannot be both NOne.
"""
"""
assert
isinstance
(
triggerable
,
Triggerable
),
type
(
triggerable
)
assert
isinstance
(
triggerable
,
Callback
),
type
(
triggerable
)
super
(
PeriodicTrigger
,
self
)
.
__init__
(
triggerable
)
super
(
PeriodicTrigger
,
self
)
.
__init__
(
triggerable
)
assert
(
every_k_epochs
is
not
None
)
or
(
every_k_steps
is
not
None
),
\
assert
(
every_k_epochs
is
not
None
)
or
(
every_k_steps
is
not
None
),
\
"every_k_steps and every_k_epochs cannot be both None!"
"every_k_steps and every_k_epochs cannot be both None!"
...
@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback):
...
@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback):
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is.
This wrapper is legacy. It will only proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per
To be able to schedule a callback more frequent than once per epoch, use :class:`PeriodicTrigger` instead.
epoch, use :class:`PeriodicTrigger` instead.
"""
"""
def
__init__
(
self
,
cb
,
period
):
def
__init__
(
self
,
cb
,
period
):
...
@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback):
...
@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback):
"""
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
self
.
period
=
int
(
period
)
log_deprecated
(
"PeriodicCallback"
,
"Use the more powerful `PeriodicTrigger`."
)
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
if
self
.
epoch_num
%
self
.
period
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment