Commit 3af5f874 authored by Yuxin Wu's avatar Yuxin Wu

exception handling in humanparam

parent 6d299170
...@@ -62,17 +62,25 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -62,17 +62,25 @@ class HumanHyperParamSetter(HyperParamSetter):
super(HumanHyperParamSetter, self).__init__(var_name) super(HumanHyperParamSetter, self).__init__(var_name)
def _get_current_value(self): def _get_current_value(self):
with open(self.file_name) as f: try:
lines = f.readlines() with open(self.file_name) as f:
lines = [s.strip().split(':') for s in lines] lines = f.readlines()
dic = {str(k):float(v) for k, v in lines} lines = [s.strip().split(':') for s in lines]
return dic[self.op_name] dic = {str(k):float(v) for k, v in lines}
ret = dic[self.op_name]
return ret
except:
logger.warn(
"Failed to parse {} in {}".format(
self.op_name, self.file_name))
return None
class ScheduledHyperParamSetter(HyperParamSetter): class ScheduledHyperParamSetter(HyperParamSetter):
def __init__(self, var_name, schedule): def __init__(self, var_name, schedule):
""" """
schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...] schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
""" """
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0)) self.schedule = sorted(schedule, key=operator.itemgetter(0))
super(ScheduledHyperParamSetter, self).__init__(var_name) super(ScheduledHyperParamSetter, self).__init__(var_name)
......
...@@ -8,7 +8,7 @@ from ...utils.fs import mkdir_p, download ...@@ -8,7 +8,7 @@ from ...utils.fs import mkdir_p, download
__all__ = ['ILSVRCMeta'] __all__ = ['ILSVRCMeta']
CAFFE_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object): class ILSVRCMeta(object):
def __init__(self, dir=None): def __init__(self, dir=None):
...@@ -26,7 +26,7 @@ class ILSVRCMeta(object): ...@@ -26,7 +26,7 @@ class ILSVRCMeta(object):
return dict(enumerate(lines)) return dict(enumerate(lines))
def download_caffe_meta(self): def download_caffe_meta(self):
fpath = download(CAFFE_URL, self.dir) fpath = download(CAFFE_ILSVRC12_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment