Commit 11ca8b2c authored by Yuxin Wu

ScheduleParam: set at beginning

parent 36bdc187
...@@ -234,7 +234,8 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -234,7 +234,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
Set hyperparameters by a predefined epoch-based schedule. Set hyperparameters by a predefined epoch-based schedule.
""" """
def __init__(self, param, schedule, interp=None, step_based=False): def __init__(self, param, schedule, interp=None, step_based=False,
set_at_beginning=True):
""" """
Args: Args:
param: same as in :class:`HyperParamSetter`. param: same as in :class:`HyperParamSetter`.
...@@ -250,6 +251,14 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -250,6 +251,14 @@ class ScheduledHyperParamSetter(HyperParamSetter):
every time this callback is triggered. every time this callback is triggered.
step_based (bool): interpret ``schedule`` as (step, value) instead step_based (bool): interpret ``schedule`` as (step, value) instead
of (epoch, value). of (epoch, value).
set_at_beginning (bool): at the start of training, the current value
may be different from the expected value according to the
schedule.
If this option is True, set the value anyway even though the current
epoch/step is not at the scheduled time.
If False, the value will only be set according to the
schedule, i.e. it will only be set if the current epoch/step
is at the scheduled time.
Example: Example:
.. code-block:: python .. code-block:: python
...@@ -263,6 +272,7 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -263,6 +272,7 @@ class ScheduledHyperParamSetter(HyperParamSetter):
assert interp == 'linear' assert interp == 'linear'
self.interp = interp self.interp = interp
self._step = step_based self._step = step_based
self._set_at_beginning = set_at_beginning
super(ScheduledHyperParamSetter, self).__init__(param) super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self): # override parent def _get_value_to_set(self): # override parent
...@@ -277,12 +287,17 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -277,12 +287,17 @@ class ScheduledHyperParamSetter(HyperParamSetter):
for p in range(0, self._current_point() + 1): for p in range(0, self._current_point() + 1):
v = self._get_value_to_set_at_point(p) or v v = self._get_value_to_set_at_point(p) or v
actual_value = self.param.get_value() actual_value = self.param.get_value()
current_point = "step" if self._step else "epoch" + str(self._current_point())
if v is not None and not np.isclose(v, actual_value): if v is not None and not np.isclose(v, actual_value):
logger.warn("According to scheduler {}, parameter '{}' should become {} at the current point. " logger.warn("According to scheduler {}, parameter '{}' should become {:.7g} at the current point ({}). "
"However its current value is {}. " "However its current value is {:.7g}. ".format(
"If this is the only scheduler being used, you may want to check whether your " self, self.param.readable_name, v, current_point, actual_value))
"initialization of the parameter is as expected".format( if self._set_at_beginning:
self, self.param.readable_name, v, actual_value)) logger.info("Setting '{}' to {:.7g}.".format(self.param.readable_name, v))
self.param.set_value(v)
else:
logger.warn("If there is no other scheduler being used, you may want to check whether your "
"initialization of the parameter is as expected")
def _get_value_to_set_at_point(self, point): def _get_value_to_set_at_point(self, point):
""" """
......
...@@ -69,10 +69,17 @@ class ScheduledHyperParamSetterTest(unittest.TestCase): ...@@ -69,10 +69,17 @@ class ScheduledHyperParamSetterTest(unittest.TestCase):
def testStartAfterSchedule(self): def testStartAfterSchedule(self):
scheduler = ScheduledHyperParamSetter( scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME), ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)]) [(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=False)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90) history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(len(history), 0) self.assertEqual(len(history), 0)
def testStartAfterSchedule_SetAtBeginning(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=True)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(history, {0: 0.5})
def testWarningStartInTheMiddle(self): def testWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter( scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME), ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment