Commit b9498a1a authored by Yuxin Wu

improve logging on model loading

parent 78c7488e
@@ -3,7 +3,7 @@
 # File: pool.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import numpy
+import numpy as np
 from ._common import layer_register, shape2d, shape4d
 from ..tfutils import symbolic_functions as symbf
@@ -39,3 +39,4 @@ def get_shape_str(tensors):
         shape_str = str(tensors.get_shape().as_list())
     return shape_str
@@ -10,9 +10,9 @@ import numpy as np
 import tensorflow as tf
 import six
-from ..utils import logger, EXTRA_SAVE_VARS_KEY
+from ..utils import logger
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname
+from .varmanip import SessionUpdate, get_savename_from_varname, is_training_specific_name
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
@@ -127,6 +127,7 @@ class SaverRestore(SessionInit):
                 var_dict[name].append(v)
                 chkpt_vars_used.add(name)
             else:
+                if not is_training_specific_name(v.op.name):
                     logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
             unused = vars_available - chkpt_vars_used
@@ -156,6 +157,7 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(
             ', '.join(map(str, intersect))))
         for k in variable_names - param_names:
+            if not is_training_specific_name(k):
                 logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
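To make the effect of this change concrete, here is a small self-contained sketch of the name bookkeeping that ParamRestore now performs. The variable names are invented, a trimmed-down stand-in replaces the real helper added in varmanip below, and print stands in for logger.warn:

def is_training_specific_name(name):
    # trimmed-down stand-in for the helper added below, for illustration only
    name = name.split(':')[0]
    return name.endswith(('/Adam', '/Adam_1', '/Momentum')) or 'EMA_summary/' in name

variable_names = {'conv0/W', 'conv0/W/Adam', 'conv0/W/Adam_1', 'fc0/b'}  # variables in the graph
param_names = {'conv0/W'}                                                # keys of the loaded dict

for k in variable_names - param_names:
    if not is_training_specific_name(k):
        print("Variable {} in the graph not found in the dict!".format(k))
# only 'fc0/b' is reported; the Adam slot variables are skipped silently

for k in param_names - variable_names:
    print("Variable {} in the dict not found in the graph!".format(k))
# nothing to report here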
@@ -106,6 +106,7 @@ def summary_moving_average():
     :returns: a op to maintain these average.
     """
     with tf.name_scope('EMA_summary'):
+        # TODO will produce EMA_summary/tower0/xxx. not elegant
         global_step_var = get_global_step_var()
         with tf.name_scope(None):
             averager = tf.train.ExponentialMovingAverage(
@@ -10,9 +10,10 @@ import re
 import numpy as np
 from ..utils import logger
 from ..utils.naming import *
+from .common import get_op_tensor_name
 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname']
+           'get_savename_from_varname', 'is_training_specific_name']
 def get_savename_from_varname(
         varname, varname_prefix=None,
@@ -24,7 +25,7 @@ def get_savename_from_varname(
     :returns: the name used to save the variable
     """
     name = varname
-    if 'towerp' in name:
+    if 'towerp/' in name:
         logger.error("No variable should be under 'towerp' name scope: {}".format(name))
         # don't overwrite anything in the current prediction graph
         return None
@@ -95,3 +96,24 @@ def dump_chkpt_vars(model_path):
     for n in var_names:
         result[n] = reader.get_tensor(n)
     return result
+
+def is_training_specific_name(name):
+    """
+    This is only used to improve logging.
+    :returns: guess whether this tensor is something only used in training.
+    """
+    # TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
+    name = get_op_tensor_name(name)[0]
+    if name.endswith('/Adam') or name.endswith('/Adam_1'):
+        return True
+    if name.endswith('/Momentum'):
+        return True
+    if name.endswith('/Adadelta') or name.endswith('/Adadelta_1'):
+        return True
+    if name.endswith('/RMSProp') or name.endswith('/RMSProp_1'):
+        return True
+    if name.endswith('/Adagrad'):
+        return True
+    if 'EMA_summary/' in name:
+        return True
+    return False
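As a quick illustration of how the new helper classifies names (a sketch; it assumes the module is importable as tensorpack.tfutils.varmanip, and the variable names are invented):

from tensorpack.tfutils.varmanip import is_training_specific_name

# optimizer slot variables and EMA summaries are guessed to be training-only
assert is_training_specific_name('conv0/W/Adam')          # Adam first-moment slot
assert is_training_specific_name('conv0/W/Adam_1')        # Adam second-moment slot
assert is_training_specific_name('fc0/b/Momentum')        # momentum accumulator
assert is_training_specific_name('EMA_summary/cost/EMA')  # moving-average summary
# plain model weights are not, so a missing one still triggers a warning
assert not is_training_specific_name('conv0/W')
assert not is_training_specific_name('fc0/b:0')           # tensor names are reduced to op names first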