Commit 5aab2d2d authored by Yuxin Wu's avatar Yuxin Wu

use statmonitor similar to PVANet

parent 5deebdcb
......@@ -188,7 +188,7 @@ class ClassificationError(Inferencer):
testing (because the size of test set might not be a multiple of batch size).
Therefore the result is different from averaging the error rate of each batch.
"""
def __init__(self, wrong_var_name='wrong:0', summary_name='validation_error'):
def __init__(self, wrong_var_name='wrong:0', summary_name='val_error'):
"""
:param wrong_var_name: name of the `wrong` variable
:param summary_name: the name for logging
......
......@@ -199,18 +199,16 @@ class ScheduledHyperParamSetter(HyperParamSetter):
class StatMonitorParamSetter(HyperParamSetter):
"""
Set hyperparameter by a func, if a specific stat wasn't
monotonically decreasing/increasing $a$ times out of the last $b$ epochs
Set hyperparameter by a func, when a specific stat wasn't
decreasing/increasing enough in the last $k$ epochs
"""
def __init__(self, param, stat_name, value_func,
last_k=5,
min_non_decreasing=2,
reverse=False
def __init__(self, param, stat_name, value_func, threshold,
last_k, reverse=False
):
"""
Change param by `new_value = value_func(old_value)`,
if `stat_name` wasn't decreasing >=2 times in the lastest 5 times of
statistics update.
if `stat_name` wasn't decreasing > threshold times in the lastest
last_k times of statistics update.
For example, if error wasn't decreasing, anneal the learning rate:
StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
......@@ -221,13 +219,10 @@ class StatMonitorParamSetter(HyperParamSetter):
self.stat_name = stat_name
self.value_func = value_func
self.last_k = last_k
self.min_non_decreasing = min_non_decreasing
self.last_changed_epoch = 0
self.threshold = threshold
self.reverse = reverse
if not reverse:
self.less_than = lambda x, y: x <= y
else:
self.less_than = lambda x, y: x >= y
self.last_changed_epoch = 0
def _get_value_to_set(self):
holder = self.trainer.stat_holder
......@@ -236,13 +231,16 @@ class StatMonitorParamSetter(HyperParamSetter):
self.epoch_num - self.last_changed_epoch < self.last_k:
return None
hist = hist[-self.last_k-1:] # len==last_k+1
cnt = 0
for k in range(self.last_k):
if self.less_than(hist[k], hist[k+1]):
cnt += 1
if cnt >= self.min_non_decreasing \
and self.less_than(hist[0], hist[-1]):
self.last_changed_epoch = self.epoch_num
return self.value_func(self.get_current_value())
return None
hist_first = hist[0]
if not self.reverse:
hist_min = min(hist)
if hist_min <= hist_first - self.threshold: # small enough
return None
else:
hist_max = max(hist)
if hist_max >= hist_first + self.threshold: # large enough
return None
self.last_changed_epoch = self.epoch_num
return self.value_func(self.get_current_value())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment