Commit cbcaef73 authored by Yuxin Wu

StatMonitorParamSetter uses the last k observations instead of the last k entries in the global history (fix #914)

parent da143e0f
......@@ -123,6 +123,11 @@ class Monitors(Callback):
for m in self._monitors:
assert isinstance(m, TrainingMonitor), m
def _setup_graph(self):
# scalar_history's other methods were not called.
# but they are not useful for now
self._scalar_history.setup_graph(self.trainer)
def _dispatch(self, func):
for m in self._monitors:
func(m)
......@@ -204,6 +209,9 @@ class Monitors(Callback):
If you run multiprocess training, keep in mind that
the data is perhaps only available on chief process.
Returns:
a list of (global_step, value) pairs: history data for this scalar
"""
return self._scalar_history.get_history(name)
......@@ -451,7 +459,7 @@ class ScalarPrinter(TrainingMonitor):
class ScalarHistory(TrainingMonitor):
"""
Only used by monitors internally.
Only internally used by monitors.
"""
def __init__(self):
......@@ -459,12 +467,12 @@ class ScalarHistory(TrainingMonitor):
@HIDE_DOC
def process_scalar(self, name, val):
self._dic[name].append(float(val))
self._dic[name].append((self.global_step, float(val)))
def get_latest(self, name):
hist = self._dic[name]
if len(hist) == 0:
raise KeyError("Invalid key: {}".format(name))
raise KeyError("No available data for the key: {}".format(name))
else:
return hist[-1]
......
......@@ -3,6 +3,7 @@
import tensorflow as tf
from collections import deque
from abc import abstractmethod, ABCMeta
import operator
import six
......@@ -109,10 +110,18 @@ class ObjAttrParam(HyperParam):
class HyperParamSetter(Callback):
"""
An abstract base callback to set hyperparameters.
Once the :meth:`trigger()` method is called,
the method :meth:`_get_value_to_set` will be used to get a new value for the hyperparameter.
"""
_chief_only = False
"""
Also enable this hyperparam setter in the :meth:`before_train` method.
"""
_enable_before_train = True
def __init__(self, param):
"""
Args:
......@@ -165,7 +174,8 @@ class HyperParamSetter(Callback):
self._set_param()
def _before_train(self):
self._set_param()
if self._enable_before_train:
self._set_param()
def _set_param(self):
v = self.get_value_to_set()
......@@ -300,9 +310,35 @@ class HyperParamSetterWithFunc(HyperParamSetter):
class StatMonitorParamSetter(HyperParamSetter):
"""
Change the param by monitoring the change of a statistic.
Change when it wasn't decreasing/increasing enough.
Change the param by monitoring the change of a scalar statistic.
The param will be changed when the scalar does not decrease/increase enough.
Once triggered, this callback observes the latest **one** value of ``stat_name``, from the monitor backend.
This callback will then change a hyperparameter ``param`` by ``new_value = value_func(old_value)``, if:
``min(history) >= history[0] - threshold``, where
``history = [the most recent k observations of stat_name]``
Note:
The statistic of interest must be created at a frequency higher than or equal to that of this callback.
For example, using ``PeriodicTrigger(StatMonitorParamSetter(...), every_k_steps=100)``
is meaningless if the statistic to be monitored is only updated every 500 steps.
Callbacks are executed in order. Therefore, if the statistic to be monitored
is created after this callback runs, the behavior of this callback may be delayed.
Example:
If the validation error has not decreased for 5 epochs, multiply the learning rate by 0.2:
.. code-block:: python
StatMonitorParamSetter('learning_rate', 'val-error',
lambda x: x * 0.2, threshold=0, last_k=5)
"""
_enable_before_train = False
def __init__(self, param, stat_name, value_func, threshold,
last_k, reverse=False):
"""
......@@ -312,27 +348,14 @@ class StatMonitorParamSetter(HyperParamSetter):
value_func (float -> float): a function which returns a new value
taking the old value.
threshold (float): change threshold.
last_k (int): last k epochs.
last_k (int): use the last k observations of the statistic.
reverse (bool): monitor increasing instead of decreasing.
This callback will change ``param`` by ``new_value = value_func(old_value)``, when:
``min(stats) >= stats[0] - threshold``, where
``stats = [the values of stat_name in last k epochs]``
If ``reverse`` is True, it will change the ``param`` when:
``max(stats) <= stats[0] + threshold``.
Example:
If validation error wasn't decreasing for 5 epochs, anneal the learning rate by 0.2:
.. code-block:: python
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
If True, ``param`` will be changed when ``max(history) <= history[0] + threshold``.
"""
super(StatMonitorParamSetter, self).__init__(param)
self.stat_name = stat_name
self.value_func = value_func
self.last_k = last_k
self.history = deque(maxlen=last_k)
self.threshold = threshold
self.reverse = reverse
......@@ -340,28 +363,34 @@ class StatMonitorParamSetter(HyperParamSetter):
def _get_value_to_set(self):
try:
hist = self.trainer.monitors.get_history(self.stat_name)
except KeyError:
last = self.trainer.monitors.get_history(self.stat_name)[-1]
except (KeyError, IndexError):
logger.warn(
"[StatMonitorParamSetter] Key {} not found in monitor history! Ignore it.".format(self.stat_name))
"[StatMonitorParamSetter] No history data available for key '{}'.".format(self.stat_name))
return None
if len(self.history) and last[0] == self.history[-1][0]:
logger.warn("StatMonitorParamSetter is triggered, but no new data has been added since last time.")
return None
if len(hist) < self.last_k + 1 or \
self.epoch_num - self.last_changed_epoch < self.last_k:
self.history.append(last)
if len(self.history) < self.history.maxlen or \
self.epoch_num - self.last_changed_epoch < self.history.maxlen:
# history is not full yet, or the param value was changed too recently
return None
hist = hist[-self.last_k - 1:] # len==last_k+1
hist_first = hist[0]
values = [k[1] for k in self.history]
hist_first = values[0]
if not self.reverse:
hist_min = min(hist)
hist_min = min(values)
if hist_min < hist_first - self.threshold: # small enough
return None
else:
hist_max = max(hist)
hist_max = max(values)
if hist_max > hist_first + self.threshold: # large enough
return None
self.last_changed_epoch = self.epoch_num
logger.info(
"[StatMonitorParamSetter] Triggered, history of {}: ".format(
self.stat_name) + ','.join([str(round(x, 3)) for x in hist]))
self.stat_name) + ','.join([str(round(x, 3)) for x in values]))
return self.value_func(self.get_current_value())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment