Commit 6e562112 authored by Yuxin Wu's avatar Yuxin Wu

param setter object attr

parent 704bee73
......@@ -46,7 +46,8 @@ class Callback(object):
"""
self.trainer = trainer
self.graph = tf.get_default_graph()
self.epoch_num = self.trainer.config.starting_epoch
self.epoch_num = self.trainer.config.starting_epoch - 1
# self.epoch_num is always the number of epochs that have finished updating parameters.
self._setup_graph()
def _setup_graph(self):
......@@ -81,8 +82,8 @@ class Callback(object):
In this function, self.epoch_num is the number of epochs that have finished.
"""
self._trigger_epoch()
self.epoch_num += 1
self._trigger_epoch()
def _trigger_epoch(self):
pass
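A minimal sketch of how the new ordering is observed from a subclass (the EpochCounter class below is hypothetical, not part of this commit): because epoch_num is incremented before _trigger_epoch runs, the hook always sees the count of completed epochs.

class EpochCounter(Callback):
    def _trigger_epoch(self):
        # after the k-th epoch of training, self.epoch_num == k here
        # (it starts at starting_epoch - 1 before any epoch has run)
        logger.info("Finished {} epochs so far.".format(self.epoch_num))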
......@@ -117,7 +118,7 @@ class PeriodicCallback(ProxyCallback):
self.period = int(period)
def _trigger_epoch(self):
self.cb.epoch_num = self.epoch_num - 1
if self.epoch_num % self.period == 0:
self.cb.epoch_num = self.epoch_num - 1
self.cb.trigger_epoch()
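A hedged usage sketch (ModelSaver stands in for any per-epoch callback): the wrapped callback's epoch_num is now synced only when the period actually fires, and the - 1 compensates for the += 1 that trigger_epoch() itself performs.

# fire the inner callback on every 3rd epoch; names are illustrative
cb = PeriodicCallback(ModelSaver(), period=3)
# at epochs 3, 6, 9, ... the inner callback ends up with the same
# epoch_num as the wrapper, since trigger_epoch() increments it by one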
......@@ -20,17 +20,26 @@ class HyperParamSetter(Callback):
"""
__metaclass__ = ABCMeta
# TODO maybe support InputVar?
def __init__(self, var_name, shape=[]):
TF_VAR = 0
OBJ_ATTR = 1
def __init__(self, param, shape=[]):
"""
:param var_name: name of the variable
:param shape: shape of the variable
:param param: either the name of a variable in the graph, or an (object, attribute) tuple
:param shape: shape of the param (only used when param is a variable name)
"""
self.op_name, self.var_name = get_op_var_name(var_name)
if isinstance(param, tuple):
self.param_type = HyperParamSetter.OBJ_ATTR
self.obj_attr = param
self.readable_name = param[1]
else:
self.param_type = HyperParamSetter.TF_VAR
self.readable_name, self.var_name = get_op_var_name(param)
self.shape = shape
self.last_value = None
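A short sketch of the two accepted forms of param (schedule values and the augmentor object below are illustrative, not from this commit):

# form 1: name of a TF variable in the graph; updated through an assign op
ScheduledHyperParamSetter('learning_rate', [(1, 1e-1), (30, 1e-2)])
# form 2: an (object, attribute) tuple; updated with setattr(obj, attr, value)
ScheduledHyperParamSetter((augmentor, 'strength'), [(1, 0.1), (20, 0.5)])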
def _setup_graph(self):
if self.param_type == HyperParamSetter.TF_VAR:
all_vars = tf.all_variables()
for v in all_vars:
if v.name == self.var_name:
......@@ -40,7 +49,7 @@ class HyperParamSetter(Callback):
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
name=self.op_name + '_feed')
name=self.readable_name + '_feed')
self.assign_op = self.var.assign(self.val_holder)
def get_current_value(self):
......@@ -50,7 +59,7 @@ class HyperParamSetter(Callback):
ret = self._get_current_value()
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} will change to {}".format(
self.op_name, self.epoch_num + 1, ret))
self.readable_name, self.epoch_num + 1, ret))
self.last_value = ret
return ret
......@@ -67,19 +76,21 @@ class HyperParamSetter(Callback):
def _set_param(self):
v = self.get_current_value()
if v is not None:
if self.param_type == HyperParamSetter.TF_VAR:
self.assign_op.eval(feed_dict={self.val_holder:v})
else:
setattr(self.obj_attr[0], self.obj_attr[1], v)
class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters manually by modifying a file.
"""
def __init__(self, var_name, file_name):
def __init__(self, param, file_name):
"""
:param param: same as in `HyperParamSetter.__init__`.
:param file_name: a file containing the value of the param. Each line in the file is a `name:value` pair
"""
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(var_name)
super(HumanHyperParamSetter, self).__init__(param)
def _get_current_value(self):
try:
......@@ -87,25 +98,25 @@ class HumanHyperParamSetter(HyperParamSetter):
lines = f.readlines()
lines = [s.strip().split(':') for s in lines]
dic = {str(k):float(v) for k, v in lines}
ret = dic[self.op_name]
ret = dic[self.readable_name]
return ret
except:
logger.warn(
"Failed to parse {} in {}".format(
self.op_name, self.file_name))
self.readable_name, self.file_name))
return None
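A hedged example of the expected file format and usage (file name and values invented): each line is parsed as a name:value pair, and the entry matching readable_name is returned.

# contents of 'hyper.txt' (hypothetical), one name:value pair per line:
#   learning_rate:0.001
#   dropout_prob:0.5
setter = HumanHyperParamSetter('learning_rate', 'hyper.txt')
# editing the file while training runs changes the value at the next epoch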
class ScheduledHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by a predefined schedule.
"""
def __init__(self, var_name, schedule):
def __init__(self, param, schedule):
"""
:param param: same as in `HyperParamSetter.__init__`.
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
"""
schedule = [(int(a), float(b)) for a, b in schedule]
self.schedule = sorted(schedule, key=operator.itemgetter(0))
super(ScheduledHyperParamSetter, self).__init__(var_name)
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_current_value(self):
for e, v in self.schedule:
......
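A usage sketch of the schedule format (numbers invented): entries are normalized to (int, float) and sorted by epoch, so the order they are given in does not matter.

ScheduledHyperParamSetter('learning_rate',
                          [(60, 1e-3), (1, 1e-1), (30, 1e-2)])
# equivalent to passing [(1, 1e-1), (30, 1e-2), (60, 1e-3)]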
......@@ -102,7 +102,7 @@ class PredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. The 0th worker will print logs.
:param gpuid: id of the GPU to be used
:param gpuid: id of the GPU to be used. Set to -1 to use the CPU.
:param inqueue: input queue from which to get datapoints
:param outqueue: output queue in which to put results
:param config: a `PredictConfig`
......@@ -115,10 +115,13 @@ class PredictWorker(multiprocessing.Process):
self.config = config
def run(self):
logger.info("Worker {} use GPU {}".format(self.idx, self.gpuid))
if self.gpuid >= 0:
logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)  # env values must be strings
else:
logger.info("Worker {} uses CPU".format(self.idx))
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0'):
with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
......
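A hedged construction sketch (the queue and config objects are elided): with this change, passing gpuid=-1 makes a worker log that it uses the CPU and pin its graph to /cpu:0 instead of /gpu:0.

# hypothetical mixed pool: worker 0 on the first GPU, worker 1 on the CPU
PredictWorker(0, 0, inqueue, outqueue, config).start()
PredictWorker(1, -1, inqueue, outqueue, config).start()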