Commit cbcaef73 authored by Yuxin Wu

StatMonitorParamSetter use last k observations instead of last k in global history (fix #914)

parent da143e0f
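
Below is a minimal standalone sketch (plain Python, illustrative names only, not tensorpack code) of the behavior this commit introduces: the setter keeps its own bounded window of the last k observations, keyed by global_step, instead of slicing the last k entries of the shared global history, so triggering the callback again before new data arrives adds nothing to the window.

```python
from collections import deque

# Toy model of the fix; `window` and `observe` are invented for illustration.
window = deque(maxlen=3)  # last_k = 3; entries are (global_step, value) pairs

def observe(step, value):
    if window and window[-1][0] == step:
        return False  # same global_step as before: no new data since the last trigger
    window.append((step, value))
    return True

# The monitored scalar is only produced at steps 100 and 200, but the callback
# may be triggered more often; the duplicate trigger does not grow the window.
for step, value in [(100, 0.50), (100, 0.50), (200, 0.48)]:
    print(observe(step, value), [v for _, v in window])
```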
@@ -123,6 +123,11 @@ class Monitors(Callback):
         for m in self._monitors:
             assert isinstance(m, TrainingMonitor), m

+    def _setup_graph(self):
+        # scalar_history's other methods were not called.
+        # but they are not useful for now
+        self._scalar_history.setup_graph(self.trainer)
+
     def _dispatch(self, func):
         for m in self._monitors:
             func(m)
@@ -204,6 +209,9 @@ class Monitors(Callback):
         If you run multiprocess training, keep in mind that
         the data is perhaps only available on chief process.

+        Returns:
+            a list of (global_step, value) pairs: history data for this scalar
         """
         return self._scalar_history.get_history(name)
@@ -451,7 +459,7 @@ class ScalarPrinter(TrainingMonitor):

 class ScalarHistory(TrainingMonitor):
     """
-    Only used by monitors internally.
+    Only internally used by monitors.
     """

     def __init__(self):
@@ -459,12 +467,12 @@ class ScalarHistory(TrainingMonitor):

     @HIDE_DOC
     def process_scalar(self, name, val):
-        self._dic[name].append(float(val))
+        self._dic[name].append((self.global_step, float(val)))

     def get_latest(self, name):
         hist = self._dic[name]
         if len(hist) == 0:
-            raise KeyError("Invalid key: {}".format(name))
+            raise KeyError("No available data for the key: {}".format(name))
         else:
             return hist[-1]
@@ -3,6 +3,7 @@

 import tensorflow as tf
+from collections import deque
 from abc import abstractmethod, ABCMeta
 import operator
 import six
@@ -109,10 +110,18 @@ class ObjAttrParam(HyperParam):
 class HyperParamSetter(Callback):
     """
     An abstract base callback to set hyperparameters.
+
+    Once the :meth:`trigger()` method is called,
+    the method :meth:`_get_value_to_set` will be used to get a new value for the hyperparameter.
     """

     _chief_only = False

+    """
+    Also enable this hyperparam setter in the :meth:`before_train` method.
+    """
+    _enable_before_train = True
+
     def __init__(self, param):
         """
         Args:
@@ -165,7 +174,8 @@ class HyperParamSetter(Callback):
         self._set_param()

     def _before_train(self):
-        self._set_param()
+        if self._enable_before_train:
+            self._set_param()

     def _set_param(self):
         v = self.get_value_to_set()
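
For illustration, a hypothetical subclass (the name `WarmupAwareSetter` is invented here, not part of this diff) shows how the new `_enable_before_train` flag from the two hunks above is meant to be used: a setter that depends on data which does not exist before training simply overrides the class attribute, exactly as `StatMonitorParamSetter` does further down.

```python
from tensorpack.callbacks import HyperParamSetter

class WarmupAwareSetter(HyperParamSetter):
    """Hypothetical example: this setter needs monitor data that only exists
    after training has started, so it skips the initial _set_param() call."""
    _enable_before_train = False

    def _get_value_to_set(self):
        # Returning None means "leave the hyperparameter unchanged this time".
        return None
```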
@@ -300,9 +310,35 @@ class HyperParamSetterWithFunc(HyperParamSetter):
 class StatMonitorParamSetter(HyperParamSetter):
     """
-    Change the param by monitoring the change of a statistic.
-    Change when it wasn't decreasing/increasing enough.
+    Change the param by monitoring the change of a scalar statistics.
+    The param will be changed when the scalar does not decrease/increase enough.
+
+    Once triggered, this callback observes the latest **one** value of ``stat_name``, from the monitor backend.
+
+    This callback will then change a hyperparameter ``param`` by ``new_value = value_func(old_value)``, if:
+    ``min(history) >= history[0] - threshold``, where
+    ``history = [the most recent k observations of stat_name]``
+
+    Note:
+        The statistics of interest must be created at a frequency higher than or equal to this callback.
+        For example, using ``PeriodicTrigger(StatMonitorParamSetter(...), every_k_steps=100)``
+        is meaningless if the statistics to be monitored is only updated every 500 steps.
+
+        Callbacks are executed in order. Therefore, if the statistics to be monitored
+        is created after this callback, the behavior of this callback may get delayed.
+
+    Example:
+        If validation error wasn't decreasing for 5 epochs, decay the learning rate by 0.2:
+
+        .. code-block:: python
+
+            StatMonitorParamSetter('learning_rate', 'val-error',
+                                   lambda x: x * 0.2, threshold=0, last_k=5)
     """

+    _enable_before_train = False
+
     def __init__(self, param, stat_name, value_func, threshold,
                  last_k, reverse=False):
         """
@@ -312,27 +348,14 @@ class StatMonitorParamSetter(HyperParamSetter):
             value_func (float -> float): a function which returns a new value
                 taking the old value.
             threshold (float): change threshold.
-            last_k (int): last k epochs.
+            last_k (int): use last k observations of statistics.
             reverse (bool): monitor increasing instead of decreasing.
+                If True, ``param`` will be changed when ``max(history) <= history[0] + threshold``.
-
-        This callback will change ``param`` by ``new_value = value_func(old_value)``, when:
-        ``min(stats) >= stats[0] - threshold``, where
-        ``stats = [the values of stat_name in last k epochs]``
-
-        If ``reverse`` is True, it will change the ``param`` when:
-        ``max(stats) <= stats[0] + threshold``.
-
-        Example:
-            If validation error wasn't decreasing for 5 epochs, anneal the learning rate by 0.2:
-
-            .. code-block:: python
-
-                StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2, 0, 5)
         """
         super(StatMonitorParamSetter, self).__init__(param)
         self.stat_name = stat_name
         self.value_func = value_func
-        self.last_k = last_k
+        self.history = deque(maxlen=last_k)
         self.threshold = threshold
         self.reverse = reverse
@@ -340,28 +363,34 @@ class StatMonitorParamSetter(HyperParamSetter):
     def _get_value_to_set(self):
         try:
-            hist = self.trainer.monitors.get_history(self.stat_name)
-        except KeyError:
+            last = self.trainer.monitors.get_history(self.stat_name)[-1]
+        except (KeyError, IndexError):
             logger.warn(
-                "[StatMonitorParamSetter] Key {} not found in monitor history! Ignore it.".format(self.stat_name))
+                "[StatMonitorParamSetter] No history data available for key '{}'.".format(self.stat_name))
+            return None
+
+        if len(self.history) and last[0] == self.history[-1][0]:
+            logger.warn("StatMonitorParamSetter is triggered, but no new data has been added since last time.")
             return None
-        if len(hist) < self.last_k + 1 or \
-                self.epoch_num - self.last_changed_epoch < self.last_k:
+        self.history.append(last)
+
+        if len(self.history) < self.history.maxlen or \
+                self.epoch_num - self.last_changed_epoch < self.history.maxlen:
+            # not full yet, or value have changed just now
             return None
-        hist = hist[-self.last_k - 1:]    # len==last_k+1
-        hist_first = hist[0]
+
+        values = [k[1] for k in self.history]
+        hist_first = values[0]
         if not self.reverse:
-            hist_min = min(hist)
+            hist_min = min(values)
             if hist_min < hist_first - self.threshold:  # small enough
                 return None
         else:
-            hist_max = max(hist)
+            hist_max = max(values)
             if hist_max > hist_first + self.threshold:  # large enough
                 return None
         self.last_changed_epoch = self.epoch_num
         logger.info(
             "[StatMonitorParamSetter] Triggered, history of {}: ".format(
                 self.stat_name) + ','.join([str(round(x, 3)) for x in values]))
         return self.value_func(self.get_current_value())