Commit 772af23d authored by Yuxin Wu's avatar Yuxin Wu

use hyper.txt from LOG_DIR to avoid conflict

parent 7a3e4c4d
......@@ -29,9 +29,9 @@ class HyperParam(object):
""" define how the value of the param will be set"""
pass
@abstractproperty
def readable_name(self):
pass
""" A name to display"""
return self._readable_name
class GraphVarParam(HyperParam):
""" a variable in the graph"""
......@@ -56,23 +56,20 @@ class GraphVarParam(HyperParam):
def set_value(self, v):
self.assign_op.eval(feed_dict={self.val_holder:v})
@property
def readable_name(self):
return self._readable_name
class ObjAttrParam(HyperParam):
""" an attribute of an object"""
def __init__(self, obj, attrname):
def __init__(self, obj, attrname, readable_name=None):
""" :param readable_name: default to be attrname."""
self.obj = obj
self.attrname = attrname
if readable_name is None:
self._readable_name = attrname
else:
self._readable_name = readable_name
def set_value(self, v):
setattr(self.obj, self.attrname, v)
@property
def readable_name(self):
return self.attrname
class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
......@@ -121,16 +118,20 @@ class HyperParamSetter(Callback):
class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters manually by modifying a file.
Set hyperparameters by loading the value from a file each time it get called.
"""
def __init__(self, param, file_name):
"""
:param file_name: a file containing the value of the variable. Each line in the file is a k:v pair
:param file_name: a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
param.readable_name, and v is the value
"""
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name)
logger.info("Use {} for hyperparam {}.".format(
self.file_name, self.param.readable_name))
def _get_value_to_set(self):
def _get_value_to_set(self):
try:
with open(self.file_name) as f:
lines = f.readlines()
......@@ -151,6 +152,7 @@ class ScheduledHyperParamSetter(HyperParamSetter):
def __init__(self, param, schedule):
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
The value is fixed to val1 in epoch [epoch1, epoch2), and so on.
"""
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment