Commit f603636c authored by Yuxin Wu's avatar Yuxin Wu

group warning messages in SessionInit

parent 019ff1a5
...@@ -78,6 +78,21 @@ class CheckpointReaderAdapter(object): ...@@ -78,6 +78,21 @@ class CheckpointReaderAdapter(object):
return name[:-2] return name[:-2]
class MismatchLogger(object):
def __init__(self, exists, nonexists):
self._exists = exists
self._nonexists = nonexists
self._names = []
def add(self, name):
self._names.append(name)
def log(self):
if len(self._names):
logger.warn("The following variables are in the {}, but not found in the {}: {}".format(
self._exists, self._nonexists, ', '.join(self._names)))
class SaverRestore(SessionInit): class SaverRestore(SessionInit):
""" """
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`. Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
...@@ -114,6 +129,8 @@ class SaverRestore(SessionInit): ...@@ -114,6 +129,8 @@ class SaverRestore(SessionInit):
reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path) reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
graph_vars = tf.global_variables() graph_vars = tf.global_variables()
chkpt_vars_used = set() chkpt_vars_used = set()
mismatch = MismatchLogger('graph', 'checkpoint')
for v in graph_vars: for v in graph_vars:
name = get_savename_from_varname(v.name, varname_prefix=self.prefix) name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
if name in self.ignore and reader.has_tensor(name): if name in self.ignore and reader.has_tensor(name):
...@@ -125,12 +142,15 @@ class SaverRestore(SessionInit): ...@@ -125,12 +142,15 @@ class SaverRestore(SessionInit):
else: else:
vname = v.op.name vname = v.op.name
if not is_training_name(vname): if not is_training_name(vname):
logger.warn("Variable {} in the graph not found in checkpoint!".format(vname)) mismatch.add(vname)
mismatch.log()
mismatch = MismatchLogger('checkpoint', 'graph')
if len(chkpt_vars_used) < len(chkpt_vars): if len(chkpt_vars_used) < len(chkpt_vars):
unused = chkpt_vars - chkpt_vars_used unused = chkpt_vars - chkpt_vars_used
for name in sorted(unused): for name in sorted(unused):
if not is_training_name(name): if not is_training_name(name):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name)) mismatch.add(name)
mismatch.log()
def _get_restore_dict(self): def _get_restore_dict(self):
var_dict = {} var_dict = {}
...@@ -185,11 +205,16 @@ class DictRestore(SessionInit): ...@@ -185,11 +205,16 @@ class DictRestore(SessionInit):
logger.info("Params to restore: {}".format( logger.info("Params to restore: {}".format(
', '.join(map(str, intersect)))) ', '.join(map(str, intersect))))
mismatch = MismatchLogger('graph', 'dict')
for k in sorted(variable_names - param_names): for k in sorted(variable_names - param_names):
if not is_training_name(k): if not is_training_name(k):
logger.warn("Variable {} in the graph not found in the dict!".format(k)) mismatch.add(k)
mismatch.log()
mismatch = MismatchLogger('dict', 'graph')
for k in sorted(param_names - variable_names): for k in sorted(param_names - variable_names):
logger.warn("Variable {} in the dict not found in the graph!".format(k)) mismatch.add(k)
mismatch.log()
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from dict ...") logger.info("Restoring from dict ...")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment