Commit 4f529aed authored by Fang Zhang's avatar Fang Zhang Committed by Yuxin Wu

Use last k observations instead of last k epochs. (#928)

* Use last k observations instead of last k epochs.

* rewriting in a more concise way
parent cd0f600b
...@@ -160,7 +160,7 @@ class HyperParamSetter(Callback): ...@@ -160,7 +160,7 @@ class HyperParamSetter(Callback):
logger.info("[HyperParamSetter] At global_step={}, {} changes from {:.6f} to {:.6f}".format( logger.info("[HyperParamSetter] At global_step={}, {} changes from {:.6f} to {:.6f}".format(
self.global_step, self.param.readable_name, self._last_value, ret)) self.global_step, self.param.readable_name, self._last_value, ret))
self._last_epoch_set = self.epoch_num self._last_epoch_set = self.epoch_num
self._last_value = ret self._last_value = ret
return ret return ret
@abstractmethod @abstractmethod
...@@ -363,8 +363,6 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -363,8 +363,6 @@ class StatMonitorParamSetter(HyperParamSetter):
self.threshold = threshold self.threshold = threshold
self.reverse = reverse self.reverse = reverse
self.last_changed_epoch = 0
def _get_value_to_set(self): def _get_value_to_set(self):
try: try:
last = self.trainer.monitors.get_history(self.stat_name)[-1] last = self.trainer.monitors.get_history(self.stat_name)[-1]
...@@ -378,9 +376,7 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -378,9 +376,7 @@ class StatMonitorParamSetter(HyperParamSetter):
self.history.append(last) self.history.append(last)
if len(self.history) < self.history.maxlen or \ if len(self.history) < self.history.maxlen:
self.epoch_num - self.last_changed_epoch < self.history.maxlen:
# not full yet, or value have changed just now
return None return None
values = [k[1] for k in self.history] values = [k[1] for k in self.history]
...@@ -393,7 +389,7 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -393,7 +389,7 @@ class StatMonitorParamSetter(HyperParamSetter):
hist_max = max(values) hist_max = max(values)
if hist_max > hist_first + self.threshold: # large enough if hist_max > hist_first + self.threshold: # large enough
return None return None
self.last_changed_epoch = self.epoch_num self.history.clear()
logger.info( logger.info(
"[StatMonitorParamSetter] Triggered, history of {}: ".format( "[StatMonitorParamSetter] Triggered, history of {}: ".format(
self.stat_name) + ','.join([str(round(x, 3)) for x in values])) self.stat_name) + ','.join([str(round(x, 3)) for x in values]))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment