Commit 2d720b60 authored by Yuxin Wu

schedule hyper param setter

parent 0dbfe237
@@ -147,11 +147,7 @@ def get_config():
     sess_config = get_default_sess_config(0.9)
-    lr = tf.train.exponential_decay(
-        learning_rate=1e-1,
-        global_step=get_global_step_var(),
-        decay_steps=36000,
-        decay_rate=0.1, staircase=True, name='learning_rate')
+    lr = tf.Variable(0.1, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
     return TrainConfig(
@@ -161,6 +157,8 @@ def get_config():
             StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
+            ScheduledHyperParamSetter('learning_rate',
+                                      [(82, 0.01), (123, 0.001), (300, 0.0001)])
         ]),
         session_config=sess_config,
         model=Model(n=18),
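For context, this hunk replaces the step-based exponential decay with an epoch-based schedule: the learning rate is now a plain non-trainable variable that starts at 0.1 and is set to 0.01, 0.001 and 0.0001 at epochs 82, 123 and 300 by the ScheduledHyperParamSetter callback added above. A minimal sketch of that effective piecewise-constant schedule in plain Python (lr_at_epoch is a hypothetical helper, not part of tensorpack):

    # Hypothetical helper (not part of tensorpack) sketching the effective
    # piecewise-constant schedule configured above.
    SCHEDULE = [(82, 0.01), (123, 0.001), (300, 0.0001)]

    def lr_at_epoch(epoch, initial=0.1, schedule=SCHEDULE):
        """Return the learning rate in effect at the given epoch."""
        lr = initial
        for e, v in sorted(schedule):
            if epoch >= e:
                lr = v
        return lr

    assert lr_at_epoch(1) == 0.1
    assert lr_at_epoch(82) == 0.01
    assert lr_at_epoch(200) == 0.001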
@@ -5,10 +5,13 @@
 import tensorflow as tf
 from abc import abstractmethod, ABCMeta
+import operator

 from .base import Callback
 from ..utils import logger, get_op_var_name

-__all__ = ['HyperParamSetter', 'HumanHyperParamSetter']
+__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
+           'ScheduledHyperParamSetter']

 class HyperParamSetter(Callback):
     __metaclass__ = ABCMeta
@@ -35,9 +38,9 @@ class HyperParamSetter(Callback):
     def get_current_value(self):
         ret = self._get_current_value()
-        if ret != self.last_value:
+        if ret is not None and ret != self.last_value:
             logger.info("{} at epoch {} is changed to {}".format(
-                self.var_name, self.epoch_num, ret))
+                self.op_name, self.epoch_num, ret))
         self.last_value = ret
         return ret
@@ -47,7 +50,8 @@ class HyperParamSetter(Callback):
     def _trigger_epoch(self):
         v = self.get_current_value()
-        self.assign_op.eval(feed_dict={self.val_holder:v})
+        if v is not None:
+            self.assign_op.eval(feed_dict={self.val_holder:v})

 class HumanHyperParamSetter(HyperParamSetter):
     def __init__(self, var_name, file_name):
@@ -64,3 +68,20 @@ class HumanHyperParamSetter(HyperParamSetter):
         lines = [s.strip().split(':') for s in lines]
         dic = {str(k):float(v) for k, v in lines}
         return dic[self.op_name]
+
+class ScheduledHyperParamSetter(HyperParamSetter):
+    def __init__(self, var_name, schedule):
+        """
+        schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
+        """
+        self.schedule = sorted(schedule, key=operator.itemgetter(0))
+        super(ScheduledHyperParamSetter, self).__init__(var_name)
+
+    def _get_current_value(self):
+        for e, v in self.schedule:
+            if e == self.epoch_num:
+                return v
+        return None
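The new callback only reports a value on an epoch that appears in its schedule; get_current_value and _trigger_epoch now treat None as "no change", so the underlying variable keeps its previous value between scheduled epochs. A minimal standalone sketch of that lookup rule, assuming the same schedule as in the example config (plain Python, no tensorpack imports; value_for_epoch is a made-up name):

    # Standalone sketch mimicking ScheduledHyperParamSetter._get_current_value:
    # a value is returned only on an epoch listed in the schedule; on any other
    # epoch it returns None, and the base class then skips the assignment.
    import operator

    schedule = sorted([(82, 0.01), (123, 0.001), (300, 0.0001)],
                      key=operator.itemgetter(0))

    def value_for_epoch(epoch_num):
        for e, v in schedule:
            if e == epoch_num:
                return v
        return None

    assert value_for_epoch(82) == 0.01
    assert value_for_epoch(83) is None   # no scheduled change at epoch 83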