Commit 6a2425d0 authored by Yuxin Wu

better variable name management

parent 4ee67733
@@ -8,6 +8,7 @@ import re
 from .base import Callback
 from ..utils import *
+from ..tfutils.varmanip import get_savename_from_varname

 __all__ = ['ModelSaver']
@@ -15,37 +16,42 @@ class ModelSaver(Callback):
     """
     Save the model to logger directory.
     """
-    def __init__(self, keep_recent=10, keep_freq=0.5):
+    def __init__(self, keep_recent=10, keep_freq=0.5,
+                 var_collections=tf.GraphKeys.VARIABLES):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
+        :param var_collections: a collection, or a list of collections, of variables to save.
         """
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
+        if not isinstance(var_collections, list):
+            var_collections = [var_collections]
+        self.var_collections = var_collections

     def _setup_graph(self):
+        vars = []
+        for key in self.var_collections:
+            vars.extend(tf.get_collection(key))
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
-            var_list=ModelSaver._get_vars(),
+            var_list=ModelSaver._get_var_dict(vars),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)
         self.meta_graph_written = False

     @staticmethod
-    def _get_vars():
-        vars = tf.all_variables()
+    def _get_var_dict(vars):
         var_dict = {}
         for v in vars:
-            name = v.name
-            if re.match('tower[p1-9]', name):
-                #logger.info("Skip {} when saving model.".format(name))
-                continue
-            if 'tower0/' in name:
-                new_name = name.replace('tower0/', '')
-                logger.info(
-                    "{} renamed to {} when saving model.".format(name, new_name))
-                name = new_name
-            var_dict[name] = v
+            name = get_savename_from_varname(v.name)
+            if name not in var_dict:
+                if name != v.name:
+                    logger.info(
+                        "{} renamed to {} when saving model.".format(v.name, name))
+                var_dict[name] = v
+            else:
+                logger.warn("Variable {} won't be saved \
+because {} will be saved".format(v.name, var_dict[name].name))
         return var_dict

     def _trigger_epoch(self):
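The tower-renaming logic that used to live inline in `_get_vars` now goes through `get_savename_from_varname`. A minimal sketch of what that helper is assumed to do, inferred from the old inline code above and from the `savename_prefix` lines in the varmanip hunk further down (the exact signature is an assumption, not tensorpack's actual implementation):

    import re

    # Sketch only: strip the per-GPU tower prefix ('tower0/', 'tower1/',
    # 'towerp/', ...) so multi-GPU checkpoints match single-GPU graphs.
    def get_savename_from_varname(varname, savename_prefix=None):
        name = re.sub('^tower[0-9p]+/', '', varname)
        if savename_prefix:
            name = savename_prefix + '/' + name
        return name

    get_savename_from_varname('tower0/conv1/W:0')   # -> 'conv1/W:0'
    get_savename_from_varname('conv1/W:0')          # -> 'conv1/W:0' (unchanged)

Since every tower then maps to the same savename, `_get_var_dict` keeps the first variable it sees for each name and warns about the rest, which is exactly what the new `else` branch above does.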
@@ -15,7 +15,7 @@ from ..tfutils import get_op_var_name
 __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
            'ScheduledHyperParamSetter',
-           'NonDecreasingStatMonitorParamSetter',
+           'StatMonitorParamSetter',
            'HyperParam', 'GraphVarParam', 'ObjAttrParam']

 class HyperParam(object):
@@ -176,14 +176,15 @@ class ScheduledHyperParamSetter(HyperParamSetter):
             return v
         return None

-class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
+class StatMonitorParamSetter(HyperParamSetter):
     """
     Set hyperparameter by a func, if a specific stat wasn't
-    monotonically decreasing $a$ times out of the last $b$ epochs
+    monotonically decreasing/increasing $a$ times out of the last $b$ epochs
     """
     def __init__(self, param, stat_name, value_func,
                  last_k=5,
-                 min_non_decreasing=2
+                 min_non_decreasing=2,
+                 reverse=False
                  ):
         """
         Change param by `new_value = value_func(old_value)`,
@@ -192,6 +193,8 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         For example, if error wasn't decreasing, anneal the learning rate:
-            NonDecreasingStatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
+            StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
+        If reverse==True, use 'increasing' instead of decreasing.
         """
-        super(NonDecreasingStatMonitorParamSetter, self).__init__(param)
+        super(StatMonitorParamSetter, self).__init__(param)
         self.stat_name = stat_name
@@ -200,6 +203,11 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         self.min_non_decreasing = min_non_decreasing
         self.last_changed_epoch = 0
+        if not reverse:
+            self.less_than = lambda x, y: x <= y
+        else:
+            self.less_than = lambda x, y: x >= y

     def _get_value_to_set(self):
         holder = self.trainer.stat_holder
         hist = holder.get_stat_history(self.stat_name)
@@ -209,10 +217,10 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         hist = hist[-self.last_k-1:]    # len==last_k+1
         cnt = 0
         for k in range(self.last_k):
-            if hist[k] <= hist[k+1]:
+            if self.less_than(hist[k], hist[k+1]):
                 cnt += 1
         if cnt >= self.min_non_decreasing \
-                and hist[-1] >= hist[0]:
+                and self.less_than(hist[0], hist[-1]):
             return self.value_func(self.get_current_value())
         return None
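A hedged usage sketch of the renamed callback; the stat names and the surrounding callbacks list are illustrative, while the constructor arguments come from this diff:

    callbacks = [
        # anneal the learning rate 5x if 'val-error' failed to decrease
        # in at least 2 of the last 5 epochs:
        StatMonitorParamSetter('learning_rate', 'val-error',
                               lambda x: x * 0.2,
                               last_k=5, min_non_decreasing=2),
        # reverse=True flips every comparison, i.e. act when a stat such
        # as validation accuracy has stopped increasing:
        StatMonitorParamSetter('learning_rate', 'val-accuracy',
                               lambda x: x * 0.2, reverse=True),
    ]

Note that despite its name, `less_than` is defined with `<=`, so an epoch with an unchanged stat also counts toward `min_non_decreasing`.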
@@ -36,7 +36,6 @@ def get_savename_from_varname(
         name = savename_prefix + '/' + name
     return name

-
 class SessionUpdate(object):
     """ Update the variables in a session """
     def __init__(self, sess, vars_to_update):
@@ -87,7 +86,10 @@ def dump_session_params(path):
     var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
     result = {}
     for v in var:
-        name = v.name.replace(":0", "")
+        name = get_savename_from_varname(v.name)
+        if name in result:
+            logger.info("Variable {} would be stored instead of another with \
+the same name".format(v.name))
         result[name] = v.eval()
     logger.info("Variables to save to {}:".format(path))
     logger.info(str(result.keys()))
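With this change, the dict written by `dump_session_params` is keyed by savenames rather than raw variable names. A sketch of the round trip; the `.npy` format and the `np.load(...).item()` call are assumptions consistent with the `result[name] = v.eval()` dict built above:

    import numpy as np

    # requires a default session, since dump_session_params calls v.eval()
    dump_session_params('model.npy')
    params = np.load('model.npy').item()   # e.g. {'conv1/W': ndarray, ...}
    print(sorted(params.keys()))           # tower prefixes already stripped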
@@ -11,7 +11,7 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'

 # placeholders for input variables
 INPUT_VARS_KEY = 'INPUT_VARIABLES'

-# variables that need to be saved, apart from trainable variables
+# variables that need to be saved for inference, apart from trainable variables
 EXTRA_SAVE_VARS_KEY = 'EXTRA_SAVE_VARIABLES'

 import tensorflow as tf
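For completeness, a hedged example of how the `EXTRA_SAVE_VARS_KEY` collection and the new `var_collections` argument fit together; the EMA variable is illustrative, and only the collection key and the `ModelSaver` signature come from this commit:

    import tensorflow as tf

    # register a non-trainable variable (e.g. a batch-norm moving average)
    # so the savers above pick it up:
    ema_mean = tf.Variable(tf.zeros([64]), trainable=False, name='bn/mean/EMA')
    tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)

    # ask ModelSaver to save it alongside the default variables:
    ModelSaver(var_collections=[tf.GraphKeys.VARIABLES, EXTRA_SAVE_VARS_KEY])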