Commit 7315d2cc authored by Yuxin Wu's avatar Yuxin Wu

fix scheduled setter to work only at exact reach

parent 833fc5e2
...@@ -225,7 +225,11 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -225,7 +225,11 @@ class ScheduledHyperParamSetter(HyperParamSetter):
to "val" **after** the completion of epoch `ep`. to "val" **after** the completion of epoch `ep`.
If ep == 0, the value will be set before the first epoch If ep == 0, the value will be set before the first epoch
(because by default the first is epoch 1). (because by default the first is epoch 1).
interp: None: no interpolation. 'linear': linear interpolation The epoch numbers have to be increasing.
interp (str or None): Either None or 'linear'.
If None, the parameter will only be set when the specific epoch or steps
is reached exactly. If 'linear', perform linear interpolation (but no extrapolation)
every time this callback is triggered.
step_based (bool): interpret ``schedule`` as (step, value) instead step_based (bool): interpret ``schedule`` as (step, value) instead
of (epoch, value). of (epoch, value).
...@@ -248,17 +252,17 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -248,17 +252,17 @@ class ScheduledHyperParamSetter(HyperParamSetter):
laste, lastv = None, None laste, lastv = None, None
for e, v in self.schedule: for e, v in self.schedule:
if e == refnum: if e == refnum:
return v return v # meet the exact boundary, return directly
if e > refnum: if e > refnum:
break break
laste, lastv = e, v laste, lastv = e, v
if laste is None or laste == e: if laste is None or laste == e:
# hasn't reached the first scheduled point, or reached the end of all scheduled points # hasn't reached the first scheduled point, or reached the end of all scheduled points
return None return None
if self.interp is not None: if self.interp is None:
v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv # If no interpolation, nothing to do.
else: return None
v = lastv v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v return v
def _trigger_epoch(self): def _trigger_epoch(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment