Commit 095c1cd9 authored by Yuxin Wu's avatar Yuxin Wu

fix repeated param restore

parent 2cd41e99
......@@ -110,6 +110,7 @@ class SaverRestore(SessionInit):
"""
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list)
chkpt_vars_used = set()
for v in vars_to_restore:
name = v.op.name
if 'towerp' in name:
......@@ -123,12 +124,14 @@ class SaverRestore(SessionInit):
name = name[len(self.prefix)+1:]
if name in vars_available:
var_dict[name].append(v)
vars_available.remove(name)
chkpt_vars_used.add(name)
#vars_available.remove(name)
else:
logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
# TODO warn if some variable in checkpoint is not used
#for name in vars_available:
#logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used
for name in unused:
logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
return var_dict
class ParamRestore(SessionInit):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment