Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
cbcaef73
Commit
cbcaef73
authored
Sep 30, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
StatMonitorParamSetter use last k observations instead of last k in global history (fix #914)
parent
da143e0f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
32 deletions
+69
-32
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+11
-3
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+58
-29
No files found.
tensorpack/callbacks/monitor.py
View file @
cbcaef73
...
@@ -123,6 +123,11 @@ class Monitors(Callback):
...
@@ -123,6 +123,11 @@ class Monitors(Callback):
for
m
in
self
.
_monitors
:
for
m
in
self
.
_monitors
:
assert
isinstance
(
m
,
TrainingMonitor
),
m
assert
isinstance
(
m
,
TrainingMonitor
),
m
def
_setup_graph
(
self
):
# scalar_history's other methods were not called.
# but they are not useful for now
self
.
_scalar_history
.
setup_graph
(
self
.
trainer
)
def
_dispatch
(
self
,
func
):
def
_dispatch
(
self
,
func
):
for
m
in
self
.
_monitors
:
for
m
in
self
.
_monitors
:
func
(
m
)
func
(
m
)
...
@@ -204,6 +209,9 @@ class Monitors(Callback):
...
@@ -204,6 +209,9 @@ class Monitors(Callback):
If you run multiprocess training, keep in mind that
If you run multiprocess training, keep in mind that
the data is perhaps only available on chief process.
the data is perhaps only available on chief process.
Returns:
a list of (global_step, value) pairs: history data for this scalar
"""
"""
return
self
.
_scalar_history
.
get_history
(
name
)
return
self
.
_scalar_history
.
get_history
(
name
)
...
@@ -451,7 +459,7 @@ class ScalarPrinter(TrainingMonitor):
...
@@ -451,7 +459,7 @@ class ScalarPrinter(TrainingMonitor):
class
ScalarHistory
(
TrainingMonitor
):
class
ScalarHistory
(
TrainingMonitor
):
"""
"""
Only
used by monitors internally
.
Only
internally used by monitors
.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -459,12 +467,12 @@ class ScalarHistory(TrainingMonitor):
...
@@ -459,12 +467,12 @@ class ScalarHistory(TrainingMonitor):
@
HIDE_DOC
@
HIDE_DOC
def
process_scalar
(
self
,
name
,
val
):
def
process_scalar
(
self
,
name
,
val
):
self
.
_dic
[
name
]
.
append
(
float
(
val
))
self
.
_dic
[
name
]
.
append
(
(
self
.
global_step
,
float
(
val
)
))
def
get_latest
(
self
,
name
):
def
get_latest
(
self
,
name
):
hist
=
self
.
_dic
[
name
]
hist
=
self
.
_dic
[
name
]
if
len
(
hist
)
==
0
:
if
len
(
hist
)
==
0
:
raise
KeyError
(
"
Invalid
key: {}"
.
format
(
name
))
raise
KeyError
(
"
No available data for the
key: {}"
.
format
(
name
))
else
:
else
:
return
hist
[
-
1
]
return
hist
[
-
1
]
...
...
tensorpack/callbacks/param.py
View file @
cbcaef73
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
collections
import
deque
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
import
operator
import
operator
import
six
import
six
...
@@ -109,10 +110,18 @@ class ObjAttrParam(HyperParam):
...
@@ -109,10 +110,18 @@ class ObjAttrParam(HyperParam):
class
HyperParamSetter
(
Callback
):
class
HyperParamSetter
(
Callback
):
"""
"""
An abstract base callback to set hyperparameters.
An abstract base callback to set hyperparameters.
Once the :meth:`trigger()` method is called,
the method :meth:`_get_value_to_set` will be used to get a new value for the hyperparameter.
"""
"""
_chief_only
=
False
_chief_only
=
False
"""
Also enable this hyperparam setter in the :meth:`before_train` method.
"""
_enable_before_train
=
True
def
__init__
(
self
,
param
):
def
__init__
(
self
,
param
):
"""
"""
Args:
Args:
...
@@ -165,7 +174,8 @@ class HyperParamSetter(Callback):
...
@@ -165,7 +174,8 @@ class HyperParamSetter(Callback):
self
.
_set_param
()
self
.
_set_param
()
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_set_param
()
if
self
.
_enable_before_train
:
self
.
_set_param
()
def
_set_param
(
self
):
def
_set_param
(
self
):
v
=
self
.
get_value_to_set
()
v
=
self
.
get_value_to_set
()
...
@@ -300,9 +310,35 @@ class HyperParamSetterWithFunc(HyperParamSetter):
...
@@ -300,9 +310,35 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class
StatMonitorParamSetter
(
HyperParamSetter
):
class
StatMonitorParamSetter
(
HyperParamSetter
):
"""
"""
Change the param by monitoring the change of a statistic.
Change the param by monitoring the change of a scalar statistics.
Change when it wasn't decreasing/increasing enough.
The param will be changed when the scalar does not decrease/increase enough.
Once triggered, this callback observes the latest **one** value of ``stat_name``, from the monitor backend.
This callback will then change a hyperparameter ``param`` by ``new_value = value_func(old_value)``, if:
``min(history) >= history[0] - threshold``, where
``history = [the most recent k observations of stat_name]``
Note:
The statistics of interest must be created at a frequency higher than or equal to this callback.
For example, using ``PeriodicTrigger(StatMonitorParamSetter(...), every_k_steps=100)``
is meaningless if the statistics to be monitored is only updated every 500 steps.
Callbacks are executed in order. Therefore, if the statistics to be monitored
is created after this callback, the behavior of this callback may get delayed.
Example:
If validation error wasn't decreasing for 5 epochs, decay the learning rate by 0.2:
.. code-block:: python
StatMonitorParamSetter('learning_rate', 'val-error',
lambda x: x * 0.2, threshold=0, last_k=5)
"""
"""
_enable_before_train
=
False
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
last_k
,
reverse
=
False
):
last_k
,
reverse
=
False
):
"""
"""
...
@@ -312,27 +348,14 @@ class StatMonitorParamSetter(HyperParamSetter):
...
@@ -312,27 +348,14 @@ class StatMonitorParamSetter(HyperParamSetter):
value_func (float -> float): a function which returns a new value
value_func (float -> float): a function which returns a new value
taking the old value.
taking the old value.
threshold (float): change threshold.
threshold (float): change threshold.
last_k (int):
last k epoch
s.
last_k (int):
use last k observations of statistic
s.
reverse (bool): monitor increasing instead of decreasing.
reverse (bool): monitor increasing instead of decreasing.
If True, ``param`` will be changed when ``max(history) <= history[0] + threshold``.
This callback will change ``param`` by ``new_value = value_func(old_value)``, when:
``min(stats) >= stats[0] - threshold``, where
``stats = [the values of stat_name in last k epochs]``
If ``reverse`` is True, it will change the ``param`` when:
``max(stats) <= stats[0] + threshold``.
Example:
If validation error wasn't decreasing for 5 epochs, anneal the learning rate by 0.2:
.. code-block:: python
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
"""
"""
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
self
.
stat_name
=
stat_name
self
.
stat_name
=
stat_name
self
.
value_func
=
value_func
self
.
value_func
=
value_func
self
.
last_k
=
last_k
self
.
history
=
deque
(
maxlen
=
last_k
)
self
.
threshold
=
threshold
self
.
threshold
=
threshold
self
.
reverse
=
reverse
self
.
reverse
=
reverse
...
@@ -340,28 +363,34 @@ class StatMonitorParamSetter(HyperParamSetter):
...
@@ -340,28 +363,34 @@ class StatMonitorParamSetter(HyperParamSetter):
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
try
:
try
:
hist
=
self
.
trainer
.
monitors
.
get_history
(
self
.
stat_name
)
last
=
self
.
trainer
.
monitors
.
get_history
(
self
.
stat_name
)[
-
1
]
except
KeyError
:
except
(
KeyError
,
IndexError
)
:
logger
.
warn
(
logger
.
warn
(
"[StatMonitorParamSetter] Key {} not found in monitor history! Ignore it."
.
format
(
self
.
stat_name
))
"[StatMonitorParamSetter] No history data available for key '{}'."
.
format
(
self
.
stat_name
))
return
None
if
len
(
self
.
history
)
and
last
[
0
]
==
self
.
history
[
-
1
][
0
]:
logger
.
warn
(
"StatMonitorParamSetter is triggered, but no new data has been added since last time."
)
return
None
return
None
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
self
.
history
.
append
(
last
)
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
if
len
(
self
.
history
)
<
self
.
history
.
maxlen
or
\
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
history
.
maxlen
:
# not full yet, or value have changed just now
return
None
return
None
hist
=
hist
[
-
self
.
last_k
-
1
:]
# len==last_k+1
hist_first
=
hist
[
0
]
values
=
[
k
[
1
]
for
k
in
self
.
history
]
hist_first
=
values
[
0
]
if
not
self
.
reverse
:
if
not
self
.
reverse
:
hist_min
=
min
(
hist
)
hist_min
=
min
(
values
)
if
hist_min
<
hist_first
-
self
.
threshold
:
# small enough
if
hist_min
<
hist_first
-
self
.
threshold
:
# small enough
return
None
return
None
else
:
else
:
hist_max
=
max
(
hist
)
hist_max
=
max
(
values
)
if
hist_max
>
hist_first
+
self
.
threshold
:
# large enough
if
hist_max
>
hist_first
+
self
.
threshold
:
# large enough
return
None
return
None
self
.
last_changed_epoch
=
self
.
epoch_num
self
.
last_changed_epoch
=
self
.
epoch_num
logger
.
info
(
logger
.
info
(
"[StatMonitorParamSetter] Triggered, history of {}: "
.
format
(
"[StatMonitorParamSetter] Triggered, history of {}: "
.
format
(
self
.
stat_name
)
+
','
.
join
([
str
(
round
(
x
,
3
))
for
x
in
hist
]))
self
.
stat_name
)
+
','
.
join
([
str
(
round
(
x
,
3
))
for
x
in
values
]))
return
self
.
value_func
(
self
.
get_current_value
())
return
self
.
value_func
(
self
.
get_current_value
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment