Commit 0dbfe237 authored by Yuxin Wu's avatar Yuxin Wu

hyperparam setter

parent 9387c653
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: param.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import abstractmethod, ABCMeta
from .base import Callback
from ..utils import logger, get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter']
class HyperParamSetter(Callback):
__metaclass__ = ABCMeta
# TODO maybe support InputVar?
def __init__(self, var_name, shape=[]):
self.op_name, self.var_name = get_op_var_name(var_name)
self.shape = shape
self.last_value = None
def _before_train(self):
all_vars = tf.all_variables()
for v in all_vars:
print v.name
if v.name == self.var_name:
self.var = v
break
else:
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
name=self.op_name + '_feed')
self.assign_op = self.var.assign(self.val_holder)
def get_current_value(self):
ret = self._get_current_value()
if ret != self.last_value:
logger.info("{} at epoch {} is changed to {}".format(
self.var_name, self.epoch_num, ret))
self.last_value = ret
return ret
@abstractmethod
def _get_current_value(self):
pass
def _trigger_epoch(self):
v = self.get_current_value()
self.assign_op.eval(feed_dict={self.val_holder:v})
class HumanHyperParamSetter(HyperParamSetter):
def __init__(self, var_name, file_name):
"""
read value from file_name.
file_name: each line in the file is a k:v pair
"""
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(var_name)
def _get_current_value(self):
with open(self.file_name) as f:
lines = f.readlines()
lines = [s.strip().split(':') for s in lines]
dic = {str(k):float(v) for k, v in lines}
return dic[self.op_name]
......@@ -10,7 +10,8 @@ import numpy as np
from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized']
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized',
'get_op_var_name']
#def expand_dim_if_necessary(var, dp):
# """
......@@ -77,3 +78,9 @@ class memoized(object):
def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
def get_op_var_name(name):
if name.endswith(':0'):
return name[:-2], name
else:
return name, name + ':0'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment