Commit 6be159d3 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Add ignore to SaverRestore (#302)

* Add ignore to SaverRestore

When fine-tuning a model, one might want to skip some tensors from
restoring.

The following example will not load the learning rate from the
checkpoint file:

  config.session_init = SaverRestore(args.load,
      ignore=['learning_rate'])

* Update sessinit.py
parent 9c5fa4c4
......@@ -82,15 +82,17 @@ class SaverRestore(SessionInit):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
"""
def __init__(self, model_path, prefix=None, ignore=None):
    """
    Args:
        model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
        prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint.
        ignore (list[str]): tensor names to skip during loading, e.g. ``['learning_rate']``.
    """
    model_path = get_checkpoint_path(model_path)
    self.path = model_path
    self.prefix = prefix
    # `ignore=None` instead of `ignore=[]` avoids the mutable-default-argument
    # pitfall; omitting the argument still yields an empty ignore list.
    if ignore is None:
        ignore = []
    # Normalize each name to the ':0' tensor-name form so callers may pass
    # either 'learning_rate' or 'learning_rate:0'.
    self.ignore = [i if i.endswith(':0') else i + ':0' for i in ignore]
def _setup_graph(self):
dic = self._get_restore_dict()
......@@ -114,6 +116,9 @@ class SaverRestore(SessionInit):
chkpt_vars_used = set()
for v in graph_vars:
name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
if name in self.ignore and reader.has_tensor(name):
logger.info("Variable {} in the graph will be not loaded from the checkpoint!".format(name))
else:
if reader.has_tensor(name):
func(reader, name, v)
chkpt_vars_used.add(name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment