Commit 6a2425d0 authored by Yuxin Wu

better variable name management

parent 4ee67733
@@ -8,6 +8,7 @@ import re
 from .base import Callback
 from ..utils import *
+from ..tfutils.varmanip import get_savename_from_varname

 __all__ = ['ModelSaver']
@@ -15,37 +16,42 @@ class ModelSaver(Callback):
     """
     Save the model to logger directory.
     """
-    def __init__(self, keep_recent=10, keep_freq=0.5):
+    def __init__(self, keep_recent=10, keep_freq=0.5,
+                 var_collections=tf.GraphKeys.VARIABLES):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
         """
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
+        if not isinstance(var_collections, list):
+            var_collections = [var_collections]
+        self.var_collections = var_collections

     def _setup_graph(self):
+        vars = []
+        for key in self.var_collections:
+            vars.extend(tf.get_collection(key))
         self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
-            var_list=ModelSaver._get_vars(),
+            var_list=ModelSaver._get_var_dict(vars),
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)
         self.meta_graph_written = False

     @staticmethod
-    def _get_vars():
-        vars = tf.all_variables()
+    def _get_var_dict(vars):
         var_dict = {}
         for v in vars:
-            name = v.name
-            if re.match('tower[p1-9]', name):
-                #logger.info("Skip {} when saving model.".format(name))
-                continue
-            if 'tower0/' in name:
-                new_name = name.replace('tower0/', '')
-                logger.info(
-                    "{} renamed to {} when saving model.".format(name, new_name))
-                name = new_name
-            var_dict[name] = v
+            name = get_savename_from_varname(v.name)
+            if name not in var_dict:
+                if name != v.name:
+                    logger.info(
+                        "{} renamed to {} when saving model.".format(v.name, name))
+                var_dict[name] = v
+            else:
+                logger.warn("Variable {} won't be saved \
+because {} will be saved".format(v.name, var_dict[name].name))
         return var_dict

     def _trigger_epoch(self):
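Note: a minimal standalone sketch of what the new `_get_var_dict` does, for readers who don't want to trace `get_savename_from_varname`. The renaming rule below (dropping the `:0` tensor suffix and a `towerN/` replica prefix) is an assumption for illustration, modeled on the old inline logic this commit replaces; the point is the deduplication behavior:

```python
import re

def get_savename_from_varname(varname):
    # assumed behavior for this sketch: drop the ':0' suffix and any
    # 'towerN/' replica prefix, like the old inline code did for 'tower0/'
    name = varname.replace(':0', '')
    return re.sub(r'^tower[0-9]+/', '', name)

def get_var_dict(varnames):
    var_dict = {}
    for v in varnames:
        name = get_savename_from_varname(v)
        if name not in var_dict:
            var_dict[name] = v   # first variable with this save-name wins
        else:
            print("{} won't be saved because {} will be saved".format(v, var_dict[name]))
    return var_dict

print(get_var_dict(['tower0/conv1/W:0', 'tower1/conv1/W:0', 'global_step:0']))
# -> {'conv1/W': 'tower0/conv1/W:0', 'global_step': 'global_step:0'}
```

With the new `var_collections` argument, callers can also choose which graph collections feed the saver; a single key such as `tf.GraphKeys.VARIABLES` (the default) is wrapped into a list automatically.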
@@ -15,7 +15,7 @@ from ..tfutils import get_op_var_name
 __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
            'ScheduledHyperParamSetter',
-           'NonDecreasingStatMonitorParamSetter',
+           'StatMonitorParamSetter',
            'HyperParam', 'GraphVarParam', 'ObjAttrParam']

 class HyperParam(object):
@@ -176,14 +176,15 @@ class ScheduledHyperParamSetter(HyperParamSetter):
                 return v
         return None

-class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
+class StatMonitorParamSetter(HyperParamSetter):
     """
     Set hyperparameter by a func, if a specific stat wasn't
-    monotonically decreasing $a$ times out of the last $b$ epochs
+    monotonically decreasing/increasing $a$ times out of the last $b$ epochs
     """
     def __init__(self, param, stat_name, value_func,
                  last_k=5,
-                 min_non_decreasing=2
+                 min_non_decreasing=2,
+                 reverse=False
                  ):
         """
         Change param by `new_value = value_func(old_value)`,
@@ -192,6 +193,8 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         For example, if error wasn't decreasing, anneal the learning rate:
-            NonDecreasingStatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
+            StatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
+        If reverse == True, use 'increasing' instead of 'decreasing'.
         """
-        super(NonDecreasingStatMonitorParamSetter, self).__init__(param)
+        super(StatMonitorParamSetter, self).__init__(param)
         self.stat_name = stat_name
@@ -200,6 +203,11 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         self.min_non_decreasing = min_non_decreasing
         self.last_changed_epoch = 0
+        if not reverse:
+            self.less_than = lambda x, y: x <= y
+        else:
+            self.less_than = lambda x, y: x >= y

     def _get_value_to_set(self):
         holder = self.trainer.stat_holder
         hist = holder.get_stat_history(self.stat_name)
@@ -209,10 +217,10 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
         hist = hist[-self.last_k-1:]    # len == last_k+1
         cnt = 0
         for k in range(self.last_k):
-            if hist[k] <= hist[k+1]:
+            if self.less_than(hist[k], hist[k+1]):
                 cnt += 1
         if cnt >= self.min_non_decreasing \
-                and hist[-1] >= hist[0]:
+                and self.less_than(hist[0], hist[-1]):
             return self.value_func(self.get_current_value())
         return None
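Note: the trigger condition above is easy to check in isolation. A self-contained sketch, where `hist` stands in for the stat history the real callback reads from `trainer.stat_holder` and the comparator flip mirrors the new `reverse` flag:

```python
def should_change(hist, last_k=5, min_non_decreasing=2, reverse=False):
    # `reverse` flips <= into >=, turning "wasn't decreasing"
    # into "wasn't increasing"
    less_than = (lambda x, y: x >= y) if reverse else (lambda x, y: x <= y)
    if len(hist) < last_k + 1:
        return False
    hist = hist[-last_k - 1:]           # len == last_k + 1
    cnt = sum(1 for k in range(last_k) if less_than(hist[k], hist[k + 1]))
    # trigger only if the stat failed to improve step-wise often enough
    # AND did not improve over the whole window
    return cnt >= min_non_decreasing and less_than(hist[0], hist[-1])

val_error = [0.20, 0.21, 0.20, 0.20, 0.21, 0.22]   # validation error plateaued
print(should_change(val_error))   # True -> e.g. anneal the learning rate
```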
@@ -36,7 +36,6 @@ def get_savename_from_varname(
         name = savename_prefix + '/' + name
     return name

-
 class SessionUpdate(object):
     """ Update the variables in a session """
     def __init__(self, sess, vars_to_update):
@@ -87,7 +86,10 @@ def dump_session_params(path):
     var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
     result = {}
     for v in var:
-        name = v.name.replace(":0", "")
+        name = get_savename_from_varname(v.name)
+        if name in result:
+            logger.info("Variable {} would be stored instead of another with \
+the same name".format(v.name))
         result[name] = v.eval()
     logger.info("Variables to save to {}:".format(path))
     logger.info(str(result.keys()))
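Note: the collision that the new `logger.info` in `dump_session_params` warns about can be reproduced with a toy stand-in for `get_savename_from_varname` (the `tower0/` stripping rule here is illustrative, not the library's exact mapping):

```python
def savename(varname):
    # illustrative stand-in for get_savename_from_varname
    return varname.replace(':0', '').replace('tower0/', '')

result = {}
for varname, value in [('W:0', 1.0), ('tower0/W:0', 2.0)]:
    name = savename(varname)
    if name in result:
        print("Variable {} would be stored instead of another "
              "with the same name".format(varname))
    result[name] = value

print(result)   # {'W': 2.0} -- the later variable silently won
```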
@@ -11,7 +11,7 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
 # placeholders for input variables
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
-# variables that need to be saved, apart from trainable variables
+# variables that need to be saved for inference, apart from trainable variables
 EXTRA_SAVE_VARS_KEY = 'EXTRA_SAVE_VARIABLES'

 import tensorflow as tf
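Note: for context on the reworded comment, `EXTRA_SAVE_VARS_KEY` is populated like any other graph collection. A hedged sketch, where the batch-norm moving average is an illustrative example of a non-trainable variable that inference still depends on:

```python
import tensorflow as tf

EXTRA_SAVE_VARS_KEY = 'EXTRA_SAVE_VARIABLES'

# a non-trainable variable that is still needed at inference time
ema_mean = tf.Variable(0.0, trainable=False, name='bn/mean/EMA')
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)

# dump_session_params / ModelSaver can then pick it up via:
extra_vars = tf.get_collection(EXTRA_SAVE_VARS_KEY)
```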