Commit 2d720b60 authored by Yuxin Wu's avatar Yuxin Wu

schedule hyper param setter

parent 0dbfe237
......@@ -147,11 +147,7 @@ def get_config():
sess_config = get_default_sess_config(0.9)
lr = tf.train.exponential_decay(
learning_rate=1e-1,
global_step=get_global_step_var(),
decay_steps=36000,
decay_rate=0.1, staircase=True, name='learning_rate')
lr = tf.Variable(0.1, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......@@ -161,6 +157,8 @@ def get_config():
StatPrinter(),
PeriodicSaver(),
ValidationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate',
[(82, 0.01), (123, 0.001), (300, 0.0001)])
]),
session_config=sess_config,
model=Model(n=18),
......
......@@ -5,10 +5,13 @@
import tensorflow as tf
from abc import abstractmethod, ABCMeta
import operator
from .base import Callback
from ..utils import logger, get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter']
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter']
class HyperParamSetter(Callback):
__metaclass__ = ABCMeta
......@@ -35,9 +38,9 @@ class HyperParamSetter(Callback):
def get_current_value(self):
ret = self._get_current_value()
if ret != self.last_value:
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} is changed to {}".format(
self.var_name, self.epoch_num, ret))
self.op_name, self.epoch_num, ret))
self.last_value = ret
return ret
......@@ -47,6 +50,7 @@ class HyperParamSetter(Callback):
def _trigger_epoch(self):
v = self.get_current_value()
if v is not None:
self.assign_op.eval(feed_dict={self.val_holder:v})
class HumanHyperParamSetter(HyperParamSetter):
......@@ -64,3 +68,20 @@ class HumanHyperParamSetter(HyperParamSetter):
lines = [s.strip().split(':') for s in lines]
dic = {str(k):float(v) for k, v in lines}
return dic[self.op_name]
class ScheduledHyperParamSetter(HyperParamSetter):
def __init__(self, var_name, schedule):
"""
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
"""
self.schedule = sorted(schedule, key=operator.itemgetter(0))
super(ScheduledHyperParamSetter, self).__init__(var_name)
def _get_current_value(self):
for e, v in self.schedule:
if e == self.epoch_num:
return v
return None
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment