Commit ab229670 authored by Yuxin Wu

Allow ignore_mismatch when loading a model to a session

parent 4d041d06
......@@ -235,6 +235,7 @@ def register_coco(basedir):
DatasetRegistry.register(name, lambda x=split: COCODetection(basedir, x))
DatasetRegistry.register_metadata(name, 'class_names', class_names)
if __name__ == '__main__':
basedir = '~/data/coco'
c = COCODetection(basedir, 'train2014')
......
......@@ -174,7 +174,9 @@ class SaverRestoreRelaxed(SaverRestore):
def f(reader, name, v):
val = reader.get_tensor(name)
v.load(SessionUpdate.relaxed_value_for_var(val, v))
val = SessionUpdate.relaxed_value_for_var(val, v, ignore_mismatch=True)
if val is not None:
v.load(val)
with sess.as_default():
self._match_vars(f)
......@@ -185,14 +187,17 @@ class DictRestore(SessionInit):
Restore variables from a dictionary.
"""
def __init__(self, variable_dict):
def __init__(self, variable_dict, ignore_mismatch=False):
"""
Args:
variable_dict (dict): a dict of {name: value}
ignore_mismatch (bool): ignore failures when the value and the
variable do not match.
"""
assert isinstance(variable_dict, dict), type(variable_dict)
# use varname (with :0) for consistency
self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}
self._ignore_mismatch = ignore_mismatch
def _run_init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
......@@ -218,7 +223,7 @@ class DictRestore(SessionInit):
mismatch.add(k)
mismatch.log()
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect], ignore_mismatch=self._ignore_mismatch)
logger.info("Restoring {} variables from dict ...".format(len(intersect)))
upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
......@@ -246,10 +251,15 @@ class ChainInit(SessionInit):
i._run_init(sess)
def get_model_loader(filename):
def get_model_loader(filename, ignore_mismatch=False):
"""
Get a corresponding model loader by looking at the file name.
Args:
filename (str): either a tensorflow checkpoint, or a npz file.
ignore_mismatch (bool): ignore failures when values in the file and
variables in the graph do not match.
Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
......@@ -258,10 +268,13 @@ def get_model_loader(filename):
filename = os.path.expanduser(filename)
if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item())
return DictRestore(np.load(filename, encoding='latin1').item(), ignore_mismatch=ignore_mismatch)
elif filename.endswith('.npz'):
assert tf.gfile.Exists(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj))
return DictRestore(dict(obj), ignore_mismatch=ignore_mismatch)
else:
if ignore_mismatch:
return SaverRestoreRelaxed(filename)
else:
return SaverRestore(filename)
......@@ -38,17 +38,20 @@ def get_savename_from_varname(
class SessionUpdate(object):
""" Update the variables in a session """
def __init__(self, sess, vars_to_update):
def __init__(self, sess, vars_to_update, ignore_mismatch=False):
"""
Args:
sess (tf.Session): a session object
vars_to_update: a collection of variables to update
ignore_mismatch (bool): ignore failures when the value and the
variable do not match.
"""
self.sess = sess
self.name_map = {v.name: v for v in vars_to_update}
self.ignore_mismatch = ignore_mismatch
@staticmethod
def relaxed_value_for_var(value, var):
def relaxed_value_for_var(value, var, ignore_mismatch=False):
"""
Returns a relaxed (possibly reshaped/upcast-ed) version of value,
to be loaded to the given variable.
......@@ -56,9 +59,13 @@ class SessionUpdate(object):
Args:
value (ndarray): an numpy array to be loaded to var
var (tf.Variable):
ignore_mismatch (bool): ignore failures when the value and the
variable do not match.
Returns:
ndarray: a possibly reshaped or casted version of value
ndarray: a possibly reshaped or casted version of value.
Returns None if `ignore_mismatch==True` and the value and the variable
mismatch.
"""
assert isinstance(var, tf.Variable)
name = var.op.name
......@@ -66,11 +73,17 @@ class SessionUpdate(object):
# check incompatible shape
varshape = tuple(var.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
if np.prod(varshape) != np.prod(value.shape):
if ignore_mismatch:
logger.warn(
"Cannot load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
value.shape, name, varshape))
return None
else:
raise ValueError(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
value.shape, name, varshape))
# TODO only allow reshape when shape different by empty axis
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
value.shape, varshape, name))
value = value.reshape(varshape)
......@@ -115,9 +128,12 @@ class SessionUpdate(object):
for name, value in six.iteritems(prms):
assert name in self.name_map
var = self.name_map[name]
fetches.append(var.initializer)
value = SessionUpdate.relaxed_value_for_var(
value, var, ignore_mismatch=self.ignore_mismatch)
# This is the implementation of `var.load`
feeds[var.initializer.inputs[1]] = SessionUpdate.relaxed_value_for_var(value, var)
if value is not None:
fetches.append(var.initializer)
feeds[var.initializer.inputs[1]] = value
self.sess.run(fetches, feed_dict=feeds)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment