Commit 2074fd50 authored by Yuxin Wu's avatar Yuxin Wu

linear interpolate schedule

parent cee79998
......@@ -161,20 +161,38 @@ class ScheduledHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by a predefined schedule.
"""
def __init__(self, param, schedule):
def __init__(self, param, schedule, interp=None):
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
The value is fixed to val1 in epoch [epoch1, epoch2), and so on.
:param interp: None: no interpolation. 'linear': linear interpolation
"""
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0))
if interp is not None:
assert interp == 'linear'
self.interp = interp
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self):
for e, v in self.schedule:
if e == self.epoch_num:
return v
return None
if self.interp is None:
for e, v in self.schedule:
if e == self.epoch_num:
return v
return None
else:
laste, lastv = None, None
for e, v in self.schedule:
if e == self.epoch_num:
return v
if e > self.epoch_num:
break
laste, lastv = e, v
if laste is None or laste == e:
# hasn't reached the first scheduled point, or reached the end of all scheduled points
return None
v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v
class StatMonitorParamSetter(HyperParamSetter):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment