Commit 2ce7ec1c authored by Yuxin Wu's avatar Yuxin Wu

ScheduleHyperParamSetter based on steps (#633)

parent 0ffdcc44
......@@ -125,7 +125,8 @@ class HyperParamSetter(Callback):
param = GraphVarParam(param)
assert isinstance(param, HyperParam), type(param)
self.param = param
self.last_value = None
self._last_value = None
self._last_epoch_set = -1
def _setup_graph(self):
self.param.setup_graph()
......@@ -141,10 +142,13 @@ class HyperParamSetter(Callback):
set, or return None to do nothing.
"""
ret = self._get_value_to_set()
if ret is not None and ret != self.last_value:
logger.info("After epoch {}, {} will change to {:.8f}".format(
self.epoch_num, self.param.readable_name, ret))
self.last_value = ret
if ret is not None and ret != self._last_value:
if self.epoch_num != self._last_epoch_set:
# Print this message at most once every epoch
logger.info("[HyperParamSetter] At global_step={}, {} will change to {:.8f}".format(
self.global_step, self.param.readable_name, ret))
self._last_epoch_set = self.epoch_num
self._last_value = ret
return ret
@abstractmethod
......@@ -213,7 +217,7 @@ class ScheduledHyperParamSetter(HyperParamSetter):
Set hyperparameters by a predefined epoch-based schedule.
"""
def __init__(self, param, schedule, interp=None):
def __init__(self, param, schedule, interp=None, step_based=False):
"""
Args:
param: same as in :class:`HyperParamSetter`.
......@@ -223,6 +227,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
If ep == 0, the value will be set before the first epoch
(because by default the first is epoch 1).
interp: None: no interpolation. 'linear': linear interpolation
step_based (bool): interpret ``schedule`` as (step, value) instead
of (epoch, value).
Example:
.. code-block:: python
......@@ -235,28 +241,38 @@ class ScheduledHyperParamSetter(HyperParamSetter):
if interp is not None:
assert interp == 'linear'
self.interp = interp
self._step = step_based
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self):
refnum = self.global_step if self._step else self.epoch_num
if self.interp is None:
for e, v in self.schedule:
if e == self.epoch_num:
if e == refnum:
return v
return None
else:
laste, lastv = None, None
for e, v in self.schedule:
if e == self.epoch_num:
if e == refnum:
return v
if e > self.epoch_num:
if e > refnum:
break
laste, lastv = e, v
if laste is None or laste == e:
# hasn't reached the first scheduled point, or reached the end of all scheduled points
return None
v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v
def _trigger_epoch(self):
if not self._step:
self.trigger()
def _trigger_step(self):
if self._step:
self.trigger()
class HyperParamSetterWithFunc(HyperParamSetter):
""" Set the parameter by a function of epoch num and old value. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment