Commit 4be41492 authored by Yuxin Wu

simplify code of the two SaverRestore classes

parent cc89b105
@@ -118,17 +118,14 @@ class SaverRestore(SessionInit):
         ckpt_vars = reader.get_variable_to_shape_map().keys()
         return reader, set(ckpt_vars)
 
-    def _get_restore_dict(self):
+    def _match_vars(self, func):
         reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
         graph_vars = tf.global_variables()
-        var_dict = {}
         chkpt_vars_used = set()
         for v in graph_vars:
             name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
             if reader.has_tensor(name):
-                ckpt_name = reader.get_real_name(name)
-                assert ckpt_name not in var_dict, "Restore conflict: {} and {}".format(v.name, var_dict[ckpt_name].name)
-                var_dict[ckpt_name] = v
+                func(reader, name, v)
                 chkpt_vars_used.add(name)
             else:
                 vname = v.op.name
@@ -139,6 +136,15 @@ class SaverRestore(SessionInit):
         for name in sorted(unused):
             if not is_training_name(name):
                 logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
 
+    def _get_restore_dict(self):
+        var_dict = {}
+
+        def f(reader, name, v):
+            name = reader.get_real_name(name)
+            assert name not in var_dict, "Restore conflict: {} and {}".format(v.name, var_dict[name].name)
+            var_dict[name] = v
+        self._match_vars(f)
         return var_dict
@@ -153,26 +159,12 @@ class SaverRestoreRelaxed(SaverRestore):
     def _run_init(self, sess):
         logger.info(
             "Restoring checkpoint from {} ...".format(self.path))
-        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
-        graph_vars = tf.global_variables()
-        chkpt_vars_used = set()
+
+        def f(reader, name, v):
+            val = reader.get_tensor(name)
+            SessionUpdate.load_value_to_var(v, val)
         with sess.as_default():
-            for v in graph_vars:
-                name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
-                if name in chkpt_vars:
-                    val = reader.get_tensor(name)
-                    SessionUpdate.load_value_to_var(v, val)
-                    chkpt_vars_used.add(name)
-                else:
-                    vname = v.op.name
-                    if not is_training_name(vname):
-                        logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
-            if len(chkpt_vars_used) < len(chkpt_vars):
-                unused = chkpt_vars - chkpt_vars_used
-                for name in sorted(unused):
-                    if not is_training_name(name):
-                        logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
+            self._match_vars(f)
 
 class ParamRestore(SessionInit):
...
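For readers skimming the change: `SaverRestore._get_restore_dict` and `SaverRestoreRelaxed._run_init` previously duplicated the loop that matches graph variables against checkpoint tensors; this commit extracts that loop into `_match_vars`, which takes a callback deciding what to do with each match. Below is a minimal self-contained sketch of the same pattern; the dict-based `reader` and the variable objects are hypothetical stand-ins for TensorFlow's checkpoint reader and `tf.Variable`, not tensorpack's actual API.

# Minimal sketch of the refactor pattern, NOT tensorpack's actual code.
# A plain dict stands in for the checkpoint reader; dicts with a "value"
# key stand in for tf.Variable objects.

def match_vars(reader, graph_vars, func):
    """Shared walk (cf. _match_vars): call func on every graph variable
    also present in the checkpoint, and warn about everything unmatched."""
    matched = set()
    for name, var in sorted(graph_vars.items()):
        if name in reader:                      # cf. reader.has_tensor(name)
            func(reader, name, var)
            matched.add(name)
        else:
            print("Variable {} in the graph not found in checkpoint!".format(name))
    for name in sorted(set(reader) - matched):
        print("Variable {} in checkpoint not found in the graph!".format(name))

def build_restore_dict(reader, graph_vars):
    """Consumer 1 (cf. _get_restore_dict): collect a {ckpt_name: variable} dict."""
    var_dict = {}

    def f(reader, name, var):
        assert name not in var_dict, "Restore conflict: {}".format(name)
        var_dict[name] = var
    match_vars(reader, graph_vars, f)
    return var_dict

def load_values(reader, graph_vars):
    """Consumer 2 (cf. SaverRestoreRelaxed._run_init): assign values directly."""
    def f(reader, name, var):
        var["value"] = reader[name]             # cf. SessionUpdate.load_value_to_var
    match_vars(reader, graph_vars, f)

if __name__ == "__main__":
    ckpt = {"conv1/W": 1.0, "conv1/b": 0.5}
    graph = {"conv1/W": {"value": None}, "fc/W": {"value": None}}
    print(build_restore_dict(ckpt, graph))      # matches conv1/W, warns on the rest
    load_values(ckpt, graph)
    print(graph["conv1/W"]["value"])            # 1.0

The design choice is a plain callback extraction: the traversal and its two warning paths now exist in exactly one place, and each caller supplies only the per-match action, which is why both hunks above shrink (+15/-14 and +12/-26 lines).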