Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3657bbd7
Commit
3657bbd7
authored
Jan 28, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make more callbacks triggerable
parent
bc0b7c63
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
85 additions
and
84 deletions
+85
-84
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+35
-30
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+1
-1
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+1
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+5
-5
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+1
-1
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-4
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+3
-4
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+36
-38
No files found.
tensorpack/callbacks/base.py
View file @
3657bbd7
...
...
@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
abc
import
ABCMeta
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
..tfutils.common
import
get_op_or_tensor_by_name
,
get_global_step_value
__all__
=
[
'Callback'
,
'P
eriodicCallback'
,
'ProxyCallback'
,
'CallbackFactory
'
]
__all__
=
[
'Callback'
,
'P
roxyCallback'
,
'CallbackFactory'
,
'Triggerable
'
]
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -128,6 +128,39 @@ class Callback(object):
return
type
(
self
)
.
__name__
@
six
.
add_metaclass
(
ABCMeta
)
class
Triggerable
(
Callback
):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be triggered either inside an epoch or between epochs.
The higher-level wrapper will take the responsibility to determine when
to trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
"""
def
trigger
(
self
):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
"""
# TODO
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
pass
def
_trigger_epoch
(
self
):
""" If used as a callback directly, run the trigger every epoch."""
self
.
trigger
()
class
ProxyCallback
(
Callback
):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
...
...
@@ -160,34 +193,6 @@ class ProxyCallback(Callback):
return
"Proxy-"
+
str
(
self
.
cb
)
class
PeriodicCallback
(
ProxyCallback
):
"""
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this method will proxy the :meth:`trigger_step` method as-is.
"""
def
__init__
(
self
,
cb
,
period
):
"""
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
self
.
cb
.
trigger_epoch
()
def
__str__
(
self
):
return
"Periodic-"
+
str
(
self
.
cb
)
class
CallbackFactory
(
Callback
):
"""
Create a callback with some lambdas.
...
...
tensorpack/callbacks/dump.py
View file @
3657bbd7
...
...
@@ -6,7 +6,7 @@ import os
import
cv2
import
numpy
as
np
from
.
trigger
import
Triggerable
from
.
base
import
Triggerable
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
...
...
tensorpack/callbacks/graph.py
View file @
3657bbd7
...
...
@@ -5,7 +5,7 @@
""" Graph related callbacks"""
from
.
trigger
import
Triggerable
from
.
base
import
Triggerable
__all__
=
[
'RunOp'
]
...
...
tensorpack/callbacks/inference_runner.py
View file @
3657bbd7
...
...
@@ -16,7 +16,7 @@ from ..tfutils import TowerContext
from
..train.input_data
import
FeedfreeInput
from
..predict
import
build_prediction_graph
from
.base
import
Callback
from
.base
import
Triggerable
from
.inference
import
Inferencer
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
]
...
...
@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
trainer
.
add_scalar_summary
(
k
,
v
)
class
InferenceRunner
(
Callback
):
class
InferenceRunner
(
Triggerable
):
"""
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
...
...
@@ -128,7 +128,7 @@ class InferenceRunner(Callback):
self
.
inf_to_tensors
=
[
find_tensors
(
t
)
for
t
in
dispatcer
.
get_names_for_each_entry
()]
# list of list of IOTensor
def
_trigger
_epoch
(
self
):
def
_trigger
(
self
):
for
inf
in
self
.
infs
:
inf
.
before_inference
()
...
...
@@ -147,7 +147,7 @@ class InferenceRunner(Callback):
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
class
FeedfreeInferenceRunner
(
Callback
):
class
FeedfreeInferenceRunner
(
Triggerable
):
""" A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
...
...
@@ -231,7 +231,7 @@ class FeedfreeInferenceRunner(Callback):
# list of list of id
self
.
inf_to_idxs
=
dispatcer
.
get_idx_for_each_entry
()
def
_trigger
_epoch
(
self
):
def
_trigger
(
self
):
for
inf
in
self
.
infs
:
inf
.
before_inference
()
...
...
tensorpack/callbacks/param.py
View file @
3657bbd7
...
...
@@ -9,7 +9,7 @@ import operator
import
six
import
os
from
.
trigger
import
Triggerable
from
.
base
import
Triggerable
from
..utils
import
logger
from
..tfutils
import
get_op_tensor_name
...
...
tensorpack/callbacks/saver.py
View file @
3657bbd7
...
...
@@ -6,10 +6,9 @@ import tensorflow as tf
import
os
import
shutil
from
.base
import
Callback
from
.base
import
Triggerable
from
..utils
import
logger
from
..tfutils.varmanip
import
get_savename_from_varname
from
.trigger
import
Triggerable
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
...
@@ -83,7 +82,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
class
MinSaver
(
Callback
):
class
MinSaver
(
Triggerable
):
"""
Separately save the model with minimum value of some statistics.
"""
...
...
@@ -126,7 +125,7 @@ class MinSaver(Callback):
return
False
return
v
>
self
.
min
if
self
.
reverse
else
v
<
self
.
min
def
_trigger
_epoch
(
self
):
def
_trigger
(
self
):
if
self
.
min
is
None
or
self
.
_need_save
():
self
.
min
=
self
.
_get_stat
()
if
self
.
min
:
...
...
tensorpack/callbacks/stats.py
View file @
3657bbd7
...
...
@@ -6,8 +6,7 @@ import os
import
operator
import
json
from
.base
import
Callback
from
.trigger
import
Triggerable
from
.base
import
Triggerable
from
..utils
import
logger
__all__
=
[
'StatHolder'
,
'StatPrinter'
,
'SendStat'
]
...
...
@@ -110,7 +109,7 @@ class StatHolder(object):
logger
.
exception
(
"Exception in StatHolder.finalize()!"
)
class
StatPrinter
(
Callback
):
class
StatPrinter
(
Triggerable
):
"""
A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
...
...
@@ -132,7 +131,7 @@ class StatPrinter(Callback):
# just try to add this stat earlier so SendStat can use
self
.
_stat_holder
.
add_stat
(
'epoch_num'
,
self
.
epoch_num
+
1
)
def
_trigger
_epoch
(
self
):
def
_trigger
(
self
):
# by default, add this two stat
self
.
_stat_holder
.
add_stat
(
'global_step'
,
self
.
global_step
)
self
.
_stat_holder
.
finalize
()
...
...
tensorpack/callbacks/trigger.py
View file @
3657bbd7
...
...
@@ -3,46 +3,10 @@
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
abstractmethod
,
ABCMeta
import
six
from
.base
import
ProxyCallback
,
Triggerable
from
.base
import
Callback
,
ProxyCallback
__all__
=
[
'Triggerable'
,
'PeriodicTrigger'
]
@
six
.
add_metaclass
(
ABCMeta
)
class
Triggerable
(
Callback
):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be triggered either inside an epoch or between epochs.
The higher-level wrapper will take the responsibility to determine when
to trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
"""
def
trigger
(
self
):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
"""
# TODO
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
pass
def
_trigger_epoch
(
self
):
""" If used as a callback directly, run the trigger every epoch."""
self
.
trigger
()
__all__
=
[
'PeriodicTrigger'
,
'PeriodicCallback'
]
class
PeriodicTrigger
(
ProxyCallback
):
...
...
@@ -78,3 +42,37 @@ class PeriodicTrigger(ProxyCallback):
return
if
self
.
epoch_num
%
self
.
_epoch_k
==
0
:
self
.
cb
.
trigger
()
def
__str__
(
self
):
return
"PeriodicTrigger-"
+
str
(
self
.
cb
)
class
PeriodicCallback
(
ProxyCallback
):
"""
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per
epoch, use :class:`PeriodicTrigger` instead.
"""
def
__init__
(
self
,
cb
,
period
):
"""
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
self
.
cb
.
trigger_epoch
()
def
__str__
(
self
):
return
"Periodic-"
+
str
(
self
.
cb
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment