Commit 6e562112 authored by Yuxin Wu

param setter object attr

parent 704bee73
@@ -46,7 +46,8 @@ class Callback(object):
         """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
-        self.epoch_num = self.trainer.config.starting_epoch
+        self.epoch_num = self.trainer.config.starting_epoch - 1
+        # self.epoch_num is always the number of epochs that finished updating parameters.
         self._setup_graph()

     def _setup_graph(self):
@@ -81,8 +82,8 @@ class Callback(object):
        In this function, self.epoch_num would be the number of epoch finished.
        """
-        self._trigger_epoch()
        self.epoch_num += 1
+        self._trigger_epoch()

    def _trigger_epoch(self):
        pass
@@ -117,7 +118,7 @@ class PeriodicCallback(ProxyCallback):
        self.period = int(period)

    def _trigger_epoch(self):
-        self.cb.epoch_num = self.epoch_num - 1
        if self.epoch_num % self.period == 0:
+            self.cb.epoch_num = self.epoch_num - 1
            self.cb.trigger_epoch()
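Taken together, the callback hunks change the meaning of self.epoch_num to "number of epochs that have finished": it starts at starting_epoch - 1, is incremented before _trigger_epoch() runs, and PeriodicCallback rewinds the wrapped callback's counter so that the wrapped trigger_epoch()'s own increment leaves both counters equal. A minimal sketch of a callback written against the new convention (the subclass is illustrative, not part of this commit):

    class EpochLogger(Callback):
        # Hypothetical example, assuming the new epoch_num semantics above.
        def _trigger_epoch(self):
            # trigger_epoch() has already incremented epoch_num,
            # so it names the epoch that just finished.
            print("epoch {} finished".format(self.epoch_num))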
@@ -20,17 +20,26 @@ class HyperParamSetter(Callback):
     """
     __metaclass__ = ABCMeta

-    # TODO maybe support InputVar?
-    def __init__(self, var_name, shape=[]):
+    TF_VAR = 0
+    OBJ_ATTR = 1
+
+    def __init__(self, param, shape=[]):
         """
-        :param var_name: name of the variable
-        :param shape: shape of the variable
+        :param param: either a name of the variable in the graph, or a (object, attribute) tuple
+        :param shape: shape of the param
         """
-        self.op_name, self.var_name = get_op_var_name(var_name)
+        if isinstance(param, tuple):
+            self.param_type = HyperParamSetter.OBJ_ATTR
+            self.obj_attr = param
+            self.readable_name = param[1]
+        else:
+            self.param_type = HyperParamSetter.TF_VAR
+            self.readable_name, self.var_name = get_op_var_name(param)
         self.shape = shape
         self.last_value = None

     def _setup_graph(self):
-        all_vars = tf.all_variables()
-        for v in all_vars:
-            if v.name == self.var_name:
+        if self.param_type == HyperParamSetter.TF_VAR:
+            all_vars = tf.all_variables()
+            for v in all_vars:
+                if v.name == self.var_name:
@@ -40,7 +49,7 @@ class HyperParamSetter(Callback):
                 raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
             self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
-                                             name=self.op_name + '_feed')
+                                             name=self.readable_name + '_feed')
             self.assign_op = self.var.assign(self.val_holder)

     def get_current_value(self):
@@ -50,7 +59,7 @@ class HyperParamSetter(Callback):
         ret = self._get_current_value()
         if ret is not None and ret != self.last_value:
             logger.info("{} at epoch {} will change to {}".format(
-                self.op_name, self.epoch_num + 1, ret))
+                self.readable_name, self.epoch_num + 1, ret))
         self.last_value = ret
         return ret
@@ -67,19 +76,21 @@ class HyperParamSetter(Callback):
     def _set_param(self):
         v = self.get_current_value()
         if v is not None:
-            self.assign_op.eval(feed_dict={self.val_holder:v})
+            if self.param_type == HyperParamSetter.TF_VAR:
+                self.assign_op.eval(feed_dict={self.val_holder:v})
+            else:
+                setattr(self.obj_attr[0], self.obj_attr[1], v)

 class HumanHyperParamSetter(HyperParamSetter):
     """
     Set hyperparameters manually by modifying a file.
     """
-    def __init__(self, var_name, file_name):
+    def __init__(self, param, file_name):
         """
-        :param var_name: name of the variable.
         :param file_name: a file containing the value of the variable. Each line in the file is a k:v pair
         """
         self.file_name = file_name
-        super(HumanHyperParamSetter, self).__init__(var_name)
+        super(HumanHyperParamSetter, self).__init__(param)

     def _get_current_value(self):
         try:
@@ -87,25 +98,25 @@ class HumanHyperParamSetter(HyperParamSetter):
             lines = f.readlines()
             lines = [s.strip().split(':') for s in lines]
             dic = {str(k):float(v) for k, v in lines}
-            ret = dic[self.op_name]
+            ret = dic[self.readable_name]
             return ret
         except:
             logger.warn(
                 "Failed to parse {} in {}".format(
-                    self.op_name, self.file_name))
+                    self.readable_name, self.file_name))
             return None

 class ScheduledHyperParamSetter(HyperParamSetter):
     """
     Set hyperparameters by a predefined schedule.
     """
-    def __init__(self, var_name, schedule):
+    def __init__(self, param, schedule):
         """
         :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
         """
         schedule = [(int(a), float(b)) for a, b in schedule]
         self.schedule = sorted(schedule, key=operator.itemgetter(0))
-        super(ScheduledHyperParamSetter, self).__init__(var_name)
+        super(ScheduledHyperParamSetter, self).__init__(param)

     def _get_current_value(self):
         for e, v in self.schedule:
......
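The net effect in the param setters: every HyperParamSetter subclass now accepts either the name of a graph variable or an (object, attribute) tuple, and _set_param dispatches to the TF assign op or to setattr accordingly. A hypothetical usage sketch (the parameter names, the my_dataflow object, and the schedule values are illustrative, not taken from this commit):

    # Set a graph variable by name, on a fixed schedule:
    ScheduledHyperParamSetter('learning_rate', [(1, 1e-2), (30, 1e-3), (60, 1e-4)])

    # Set a plain Python attribute on any object instead, reading the value
    # from a k:v file that can be edited while training runs:
    HumanHyperParamSetter((my_dataflow, 'dropout_prob'), 'hyper.txt')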
@@ -102,7 +102,7 @@ class PredictWorker(multiprocessing.Process):
     def __init__(self, idx, gpuid, inqueue, outqueue, config):
         """
         :param idx: index of the worker. the 0th worker will print log.
-        :param gpuid: id of the GPU to be used
+        :param gpuid: id of the GPU to be used. set to -1 to use CPU.
         :param inqueue: input queue to get data point
         :param outqueue: output queue put result
         :param config: a `PredictConfig`
@@ -115,10 +115,13 @@ class PredictWorker(multiprocessing.Process):
         self.config = config

     def run(self):
-        logger.info("Worker {} use GPU {}".format(self.idx, self.gpuid))
-        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        if self.gpuid >= 0:
+            logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
+            os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        else:
+            logger.info("Worker {} uses CPU".format(self.idx))
         G = tf.Graph() # build a graph for each process, because they don't need to share anything
-        with G.as_default(), tf.device('/gpu:0'):
+        with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
             if self.idx != 0:
                 from tensorpack.models._common import disable_layer_logging
                 disable_layer_logging()
......
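With gpuid doubling as a CPU switch, each worker can be pinned to a device at construction time. A hypothetical launch sketch (the queues and config are placeholders, not from this diff; gpuid is assumed to be an int per the new `>= 0` check):

    import multiprocessing
    inqueue, outqueue = multiprocessing.Queue(), multiprocessing.Queue()
    # worker 0 runs on GPU 0 and prints logs; worker 1 runs on CPU
    workers = [PredictWorker(0, 0, inqueue, outqueue, config),
               PredictWorker(1, -1, inqueue, outqueue, config)]
    for w in workers:
        w.start()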