Commit b9498a1a authored by Yuxin Wu

improve logging on model loading

parent 78c7488e
@@ -3,7 +3,7 @@
 # File: pool.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import numpy
+import numpy as np
 from ._common import layer_register, shape2d, shape4d
 from ..tfutils import symbolic_functions as symbf
...
@@ -39,3 +39,4 @@ def get_shape_str(tensors):
         shape_str = str(tensors.get_shape().as_list())
     return shape_str
@@ -10,9 +10,9 @@ import numpy as np
 import tensorflow as tf
 import six
-from ..utils import logger, EXTRA_SAVE_VARS_KEY
+from ..utils import logger
 from .common import get_op_var_name
-from .varmanip import SessionUpdate, get_savename_from_varname
+from .varmanip import SessionUpdate, get_savename_from_varname, is_training_specific_name
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
@@ -127,7 +127,8 @@ class SaverRestore(SessionInit):
                     var_dict[name].append(v)
                     chkpt_vars_used.add(name)
                 else:
-                    logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
+                    if not is_training_specific_name(v.op.name):
+                        logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
             unused = vars_available - chkpt_vars_used
             for name in unused:
@@ -156,7 +157,8 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(
             ', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            logger.warn("Variable {} in the graph not found in the dict!".format(k))
+            if not is_training_specific_name(k):
+                logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
...
@@ -106,12 +106,13 @@ def summary_moving_average():
     :returns: a op to maintain these average.
     """
     with tf.name_scope('EMA_summary'):
+        # TODO will produce EMA_summary/tower0/xxx. not elegant
         global_step_var = get_global_step_var()
         with tf.name_scope(None):
             averager = tf.train.ExponentialMovingAverage(
                 0.99, num_updates=global_step_var, name='EMA')
         vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
         avg_maintain_op = averager.apply(vars_to_summary)
         for idx, c in enumerate(vars_to_summary):
             name = re.sub('tower[p0-9]+/', '', c.op.name)
             tf.scalar_summary(name, averager.average(c))
...
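As a side note on the loop above, here is a small, self-contained sketch of what the re.sub('tower[p0-9]+/', '', ...) call does to summary op names; the example names are hypothetical and not taken from this commit.

import re

# Hypothetical op names as they might appear when the graph is built
# inside a tower name scope; the regex strips the tower prefix.
for n in ['tower0/cost', 'towerp/cost', 'cost']:
    print(n, '->', re.sub('tower[p0-9]+/', '', n))
# tower0/cost -> cost
# towerp/cost -> cost
# cost        -> cost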
@@ -10,9 +10,10 @@ import re
 import numpy as np
 from ..utils import logger
 from ..utils.naming import *
+from .common import get_op_tensor_name
 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-           'get_savename_from_varname']
+           'get_savename_from_varname', 'is_training_specific_name']
 def get_savename_from_varname(
         varname, varname_prefix=None,
@@ -24,7 +25,7 @@ def get_savename_from_varname(
     :returns: the name used to save the variable
     """
     name = varname
-    if 'towerp' in name:
+    if 'towerp/' in name:
         logger.error("No variable should be under 'towerp' name scope".format(v.name))
         # don't overwrite anything in the current prediction graph
         return None
@@ -95,3 +96,24 @@ def dump_chkpt_vars(model_path):
     for n in var_names:
         result[n] = reader.get_tensor(n)
     return result
+
+def is_training_specific_name(name):
+    """
+    This is only used to improve logging.
+    :returns: guess whether this tensor is something only used in training.
+    """
+    # TODO: maybe simply check against TRAINABLE_VARIABLES and EXTRA_SAVE_VARS_KEY ?
+    name = get_op_tensor_name(name)[0]
+    if name.endswith('/Adam') or name.endswith('/Adam_1'):
+        return True
+    if name.endswith('/Momentum'):
+        return True
+    if name.endswith('/Adadelta') or name.endswith('/Adadelta_1'):
+        return True
+    if name.endswith('/RMSProp') or name.endswith('/RMSProp_1'):
+        return True
+    if name.endswith('/Adagrad'):
+        return True
+    if 'EMA_summary/' in name:
+        return True
+    return False
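To illustrate the effect of the new check, here is a minimal, self-contained sketch of the warning-filter pattern used in SaverRestore/ParamRestore above. The variable names and the simplified stand-in helper below are hypothetical; the helper only mirrors the suffix checks added in this commit and is not the library's API.

# Minimal sketch of the filtering pattern, with a simplified stand-in
# for is_training_specific_name().
def looks_training_specific(name):
    suffixes = ('/Adam', '/Adam_1', '/Momentum', '/Adadelta', '/Adadelta_1',
                '/RMSProp', '/RMSProp_1', '/Adagrad')
    return name.endswith(suffixes) or 'EMA_summary/' in name

# Hypothetical graph variables vs. variables found in a checkpoint.
graph_vars = ['conv0/W', 'conv0/W/Adam', 'conv0/W/Adam_1',
              'fc0/b', 'EMA_summary/cost/EMA']
checkpoint_vars = {'conv0/W'}

for v in graph_vars:
    if v not in checkpoint_vars and not looks_training_specific(v):
        print("Variable {} in the graph not found in checkpoint!".format(v))
# Only 'fc0/b' is reported; optimizer slot variables and EMA summary
# variables no longer trigger warnings during model loading.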