Commit 772af23d authored by Yuxin Wu's avatar Yuxin Wu

use hyper.txt from LOG_DIR to avoid conflict

parent 7a3e4c4d
...@@ -29,9 +29,9 @@ class HyperParam(object): ...@@ -29,9 +29,9 @@ class HyperParam(object):
""" define how the value of the param will be set""" """ define how the value of the param will be set"""
pass pass
@abstractproperty
def readable_name(self): def readable_name(self):
pass """ A name to display"""
return self._readable_name
class GraphVarParam(HyperParam): class GraphVarParam(HyperParam):
""" a variable in the graph""" """ a variable in the graph"""
...@@ -56,23 +56,20 @@ class GraphVarParam(HyperParam): ...@@ -56,23 +56,20 @@ class GraphVarParam(HyperParam):
def set_value(self, v): def set_value(self, v):
self.assign_op.eval(feed_dict={self.val_holder:v}) self.assign_op.eval(feed_dict={self.val_holder:v})
@property
def readable_name(self):
return self._readable_name
class ObjAttrParam(HyperParam): class ObjAttrParam(HyperParam):
""" an attribute of an object""" """ an attribute of an object"""
def __init__(self, obj, attrname): def __init__(self, obj, attrname, readable_name=None):
""" :param readable_name: default to be attrname."""
self.obj = obj self.obj = obj
self.attrname = attrname self.attrname = attrname
if readable_name is None:
self._readable_name = attrname
else:
self._readable_name = readable_name
def set_value(self, v): def set_value(self, v):
setattr(self.obj, self.attrname, v) setattr(self.obj, self.attrname, v)
@property
def readable_name(self):
return self.attrname
class HyperParamSetter(Callback): class HyperParamSetter(Callback):
""" """
Base class to set hyperparameters after every epoch. Base class to set hyperparameters after every epoch.
...@@ -121,16 +118,20 @@ class HyperParamSetter(Callback): ...@@ -121,16 +118,20 @@ class HyperParamSetter(Callback):
class HumanHyperParamSetter(HyperParamSetter): class HumanHyperParamSetter(HyperParamSetter):
""" """
Set hyperparameters manually by modifying a file. Set hyperparameters by loading the value from a file each time it get called.
""" """
def __init__(self, param, file_name): def __init__(self, param, file_name):
""" """
:param file_name: a file containing the value of the variable. Each line in the file is a k:v pair :param file_name: a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
param.readable_name, and v is the value
""" """
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(param) super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name)
logger.info("Use {} for hyperparam {}.".format(
self.file_name, self.param.readable_name))
def _get_value_to_set(self): def _get_value_to_set(self):
try: try:
with open(self.file_name) as f: with open(self.file_name) as f:
lines = f.readlines() lines = f.readlines()
...@@ -151,6 +152,7 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -151,6 +152,7 @@ class ScheduledHyperParamSetter(HyperParamSetter):
def __init__(self, param, schedule): def __init__(self, param, schedule):
""" """
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...] :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
The value is fixed to val1 in epoch [epoch1, epoch2), and so on.
""" """
schedule = [(int(a), float(b)) for a, b in schedule] schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0)) self.schedule = sorted(schedule, key=operator.itemgetter(0))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment