Commit 509c2c90 authored by Yuxin Wu's avatar Yuxin Wu

non-decr stat monitor param setter

parent 838a4ba3
......@@ -87,7 +87,7 @@ def get_config():
StatPrinter(),
ModelSaver(),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ])
[ScalarStats('cost'), ClassificationError() ]),
]),
session_config=get_default_sess_config(0.5),
model=Model(),
......
......@@ -15,6 +15,7 @@ from ..tfutils import get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter',
'NonDecreasingStatMonitorParamSetter',
'HyperParam', 'GraphVarParam', 'ObjAttrParam']
class HyperParam(object):
......@@ -36,7 +37,7 @@ class HyperParam(object):
return self._readable_name
class GraphVarParam(HyperParam):
""" a variable in the graph"""
""" a variable in the graph can be a hyperparam"""
def __init__(self, name, shape=[]):
self.name = name
self.shape = shape
......@@ -58,8 +59,11 @@ class GraphVarParam(HyperParam):
def set_value(self, v):
self.assign_op.eval(feed_dict={self.val_holder:v})
def get_value(self):
return self.var.eval()
class ObjAttrParam(HyperParam):
""" an attribute of an object"""
""" an attribute of an object can be a hyperparam"""
def __init__(self, obj, attrname, readable_name=None):
""" :param readable_name: default to be attrname."""
self.obj = obj
......@@ -72,6 +76,9 @@ class ObjAttrParam(HyperParam):
def set_value(self, v):
setattr(self.obj, self.attrname, v)
def get_value(self, v):
return getattr(self.obj, self.attrname)
class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
......@@ -98,11 +105,14 @@ class HyperParamSetter(Callback):
"""
ret = self._get_value_to_set()
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} will change to {}".format(
logger.info("{} at epoch {} will change to {:.8f}".format(
self.param.readable_name, self.epoch_num + 1, ret))
self.last_value = ret
return ret
def get_current_value(self):
return self.param.get_value()
@abstractmethod
def _get_value_to_set(self):
pass
......@@ -166,3 +176,43 @@ class ScheduledHyperParamSetter(HyperParamSetter):
return v
return None
class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
"""
Set hyperparameter by a func, if a specific stat wasn't
monotonically decreasing $a$ times out of the last $b$ epochs
"""
def __init__(self, param, stat_name, value_func,
last_k=5,
min_non_decreasing=2
):
"""
Change param by `new_value = value_func(old_value)`,
if `stat_name` wasn't decreasing >=2 times in the lastest 5 times of
statistics update.
For example, if error wasn't decreasing, anneal the learning rate:
NonDecreasingStatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
"""
super(NonDecreasingStatMonitorParamSetter, self).__init__(param)
self.stat_name = stat_name
self.value_func = value_func
self.last_k = last_k
self.min_non_decreasing = min_non_decreasing
self.last_changed_epoch = 0
def _get_value_to_set(self):
holder = self.trainer.stat_holder
hist = holder.get_stat_history(self.stat_name)
if len(hist) < self.last_k+1 or \
self.epoch_num - self.last_changed_epoch < self.last_k:
return None
hist = hist[-self.last_k-1:] # len==last_k+1
cnt = 0
for k in range(self.last_k):
if hist[k] <= hist[k+1]:
cnt += 1
if cnt >= self.min_non_decreasing \
and hist[-1] >= hist[0]:
return self.value_func(self.get_current_value())
return None
......@@ -57,6 +57,15 @@ class StatHolder(object):
"""
return self.stat_now[key]
def get_stat_history(self, key):
ret = []
for h in self.stat_history:
v = h.get(key, None)
if v is not None: ret.append(v)
v = self.stat_now.get(key, None)
if v is not None: ret.append(v)
return ret
def finalize(self):
"""
Called after finishing adding stats. Will print and write stats to disk.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment