Commit 5aab2d2d authored by Yuxin Wu

use statmonitor similar to PVANet

parent 5deebdcb
...@@ -188,7 +188,7 @@ class ClassificationError(Inferencer): ...@@ -188,7 +188,7 @@ class ClassificationError(Inferencer):
testing (because the size of test set might not be a multiple of batch size). testing (because the size of test set might not be a multiple of batch size).
Therefore the result is different from averaging the error rate of each batch. Therefore the result is different from averaging the error rate of each batch.
""" """
def __init__(self, wrong_var_name='wrong:0', summary_name='validation_error'): def __init__(self, wrong_var_name='wrong:0', summary_name='val_error'):
""" """
:param wrong_var_name: name of the `wrong` variable :param wrong_var_name: name of the `wrong` variable
:param summary_name: the name for logging :param summary_name: the name for logging
......
...@@ -199,18 +199,16 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -199,18 +199,16 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class StatMonitorParamSetter(HyperParamSetter): class StatMonitorParamSetter(HyperParamSetter):
""" """
Set hyperparameter by a func, if a specific stat wasn't Set hyperparameter by a func, when a specific stat wasn't
monotonically decreasing/increasing $a$ times out of the last $b$ epochs decreasing/increasing enough in the last $k$ epochs
""" """
def __init__(self, param, stat_name, value_func, def __init__(self, param, stat_name, value_func, threshold,
last_k=5, last_k, reverse=False
min_non_decreasing=2,
reverse=False
): ):
""" """
Change param by `new_value = value_func(old_value)`, Change param by `new_value = value_func(old_value)`,
if `stat_name` wasn't decreasing >=2 times in the lastest 5 times of if `stat_name` wasn't decreasing > threshold times in the lastest
statistics update. last_k times of statistics update.
For example, if error wasn't decreasing, anneal the learning rate: For example, if error wasn't decreasing, anneal the learning rate:
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2) StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
...@@ -221,13 +219,10 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -221,13 +219,10 @@ class StatMonitorParamSetter(HyperParamSetter):
self.stat_name = stat_name self.stat_name = stat_name
self.value_func = value_func self.value_func = value_func
self.last_k = last_k self.last_k = last_k
self.min_non_decreasing = min_non_decreasing self.threshold = threshold
self.last_changed_epoch = 0 self.reverse = reverse
if not reverse: self.last_changed_epoch = 0
self.less_than = lambda x, y: x <= y
else:
self.less_than = lambda x, y: x >= y
def _get_value_to_set(self): def _get_value_to_set(self):
holder = self.trainer.stat_holder holder = self.trainer.stat_holder
...@@ -236,13 +231,16 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -236,13 +231,16 @@ class StatMonitorParamSetter(HyperParamSetter):
self.epoch_num - self.last_changed_epoch < self.last_k: self.epoch_num - self.last_changed_epoch < self.last_k:
return None return None
hist = hist[-self.last_k-1:] # len==last_k+1 hist = hist[-self.last_k-1:] # len==last_k+1
cnt = 0
for k in range(self.last_k): hist_first = hist[0]
if self.less_than(hist[k], hist[k+1]): if not self.reverse:
cnt += 1 hist_min = min(hist)
if cnt >= self.min_non_decreasing \ if hist_min <= hist_first - self.threshold: # small enough
and self.less_than(hist[0], hist[-1]): return None
else:
hist_max = max(hist)
if hist_max >= hist_first + self.threshold: # large enough
return None
self.last_changed_epoch = self.epoch_num self.last_changed_epoch = self.epoch_num
return self.value_func(self.get_current_value()) return self.value_func(self.get_current_value())
return None
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment