Commit 11ca8b2c authored by Yuxin Wu's avatar Yuxin Wu

ScheduleParam: set at beginning

parent 36bdc187
......@@ -234,7 +234,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
Set hyperparameters by a predefined epoch-based schedule.
"""
def __init__(self, param, schedule, interp=None, step_based=False):
def __init__(self, param, schedule, interp=None, step_based=False,
set_at_beginning=True):
"""
Args:
param: same as in :class:`HyperParamSetter`.
......@@ -250,6 +251,14 @@ class ScheduledHyperParamSetter(HyperParamSetter):
every time this callback is triggered.
step_based (bool): interpret ``schedule`` as (step, value) instead
of (epoch, value).
set_at_beginning (bool): at the start of training, the current value
may be different from the expected value according to the
schedule.
If this option is True, set the value anyway even though the current
epoch/step is not at the scheduled time.
If False, the value will only be set according to the
schedule, i.e. it will only be set if the current epoch/step
is at the scheduled time.
Example:
.. code-block:: python
......@@ -263,6 +272,7 @@ class ScheduledHyperParamSetter(HyperParamSetter):
assert interp == 'linear'
self.interp = interp
self._step = step_based
self._set_at_beginning = set_at_beginning
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self): # override parent
......@@ -277,12 +287,17 @@ class ScheduledHyperParamSetter(HyperParamSetter):
for p in range(0, self._current_point() + 1):
v = self._get_value_to_set_at_point(p) or v
actual_value = self.param.get_value()
current_point = "step" if self._step else "epoch" + str(self._current_point())
if v is not None and not np.isclose(v, actual_value):
logger.warn("According to scheduler {}, parameter '{}' should become {} at the current point. "
"However its current value is {}. "
"If this is the only scheduler being used, you may want to check whether your "
"initialization of the parameter is as expected".format(
self, self.param.readable_name, v, actual_value))
logger.warn("According to scheduler {}, parameter '{}' should become {:.7g} at the current point ({}). "
"However its current value is {:.7g}. ".format(
self, self.param.readable_name, v, current_point, actual_value))
if self._set_at_beginning:
logger.info("Setting '{}' to {:.7g}.".format(self.param.readable_name, v))
self.param.set_value(v)
else:
logger.warn("If there is no other scheduler being used, you may want to check whether your "
"initialization of the parameter is as expected")
def _get_value_to_set_at_point(self, point):
"""
......
......@@ -69,10 +69,17 @@ class ScheduledHyperParamSetterTest(unittest.TestCase):
def testStartAfterSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
[(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=False)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(len(history), 0)
def testStartAfterSchedule_SetAtBeginning(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=True)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(history, {0: 0.5})
def testWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment