Commit ab229670 authored by Yuxin Wu

Allow ignore_mismatch when loading a model to a session

parent 4d041d06
...@@ -235,6 +235,7 @@ def register_coco(basedir): ...@@ -235,6 +235,7 @@ def register_coco(basedir):
DatasetRegistry.register(name, lambda x=split: COCODetection(basedir, x)) DatasetRegistry.register(name, lambda x=split: COCODetection(basedir, x))
DatasetRegistry.register_metadata(name, 'class_names', class_names) DatasetRegistry.register_metadata(name, 'class_names', class_names)
if __name__ == '__main__': if __name__ == '__main__':
basedir = '~/data/coco' basedir = '~/data/coco'
c = COCODetection(basedir, 'train2014') c = COCODetection(basedir, 'train2014')
......
...@@ -174,7 +174,9 @@ class SaverRestoreRelaxed(SaverRestore): ...@@ -174,7 +174,9 @@ class SaverRestoreRelaxed(SaverRestore):
def f(reader, name, v): def f(reader, name, v):
val = reader.get_tensor(name) val = reader.get_tensor(name)
v.load(SessionUpdate.relaxed_value_for_var(val, v)) val = SessionUpdate.relaxed_value_for_var(val, v, ignore_mismatch=True)
if val is not None:
v.load(val)
with sess.as_default(): with sess.as_default():
self._match_vars(f) self._match_vars(f)
...@@ -185,14 +187,17 @@ class DictRestore(SessionInit): ...@@ -185,14 +187,17 @@ class DictRestore(SessionInit):
Restore variables from a dictionary. Restore variables from a dictionary.
""" """
def __init__(self, variable_dict): def __init__(self, variable_dict, ignore_mismatch=False):
""" """
Args: Args:
variable_dict (dict): a dict of {name: value} variable_dict (dict): a dict of {name: value}
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
""" """
assert isinstance(variable_dict, dict), type(variable_dict) assert isinstance(variable_dict, dict), type(variable_dict)
# use varname (with :0) for consistency # use varname (with :0) for consistency
self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)} self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}
self._ignore_mismatch = ignore_mismatch
def _run_init(self, sess): def _run_init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
...@@ -218,7 +223,7 @@ class DictRestore(SessionInit): ...@@ -218,7 +223,7 @@ class DictRestore(SessionInit):
mismatch.add(k) mismatch.add(k)
mismatch.log() mismatch.log()
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect], ignore_mismatch=self._ignore_mismatch)
logger.info("Restoring {} variables from dict ...".format(len(intersect))) logger.info("Restoring {} variables from dict ...".format(len(intersect)))
upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
...@@ -246,10 +251,15 @@ class ChainInit(SessionInit): ...@@ -246,10 +251,15 @@ class ChainInit(SessionInit):
i._run_init(sess) i._run_init(sess)
def get_model_loader(filename): def get_model_loader(filename, ignore_mismatch=False):
""" """
Get a corresponding model loader by looking at the file name. Get a corresponding model loader by looking at the file name.
Args:
filename (str): either a tensorflow checkpoint, or a npz file.
ignore_mismatch (bool): ignore failures when values in the file and
variables in the graph do not match.
Returns: Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
...@@ -258,10 +268,13 @@ def get_model_loader(filename): ...@@ -258,10 +268,13 @@ def get_model_loader(filename):
filename = os.path.expanduser(filename) filename = os.path.expanduser(filename)
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item(), ignore_mismatch=ignore_mismatch)
elif filename.endswith('.npz'): elif filename.endswith('.npz'):
assert tf.gfile.Exists(filename), filename assert tf.gfile.Exists(filename), filename
obj = np.load(filename) obj = np.load(filename)
return DictRestore(dict(obj)) return DictRestore(dict(obj), ignore_mismatch=ignore_mismatch)
else:
if ignore_mismatch:
return SaverRestoreRelaxed(filename)
else: else:
return SaverRestore(filename) return SaverRestore(filename)
...@@ -38,17 +38,20 @@ def get_savename_from_varname( ...@@ -38,17 +38,20 @@ def get_savename_from_varname(
class SessionUpdate(object): class SessionUpdate(object):
""" Update the variables in a session """ """ Update the variables in a session """
def __init__(self, sess, vars_to_update): def __init__(self, sess, vars_to_update, ignore_mismatch=False):
""" """
Args: Args:
sess (tf.Session): a session object sess (tf.Session): a session object
vars_to_update: a collection of variables to update vars_to_update: a collection of variables to update
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
""" """
self.sess = sess self.sess = sess
self.name_map = {v.name: v for v in vars_to_update} self.name_map = {v.name: v for v in vars_to_update}
self.ignore_mismatch = ignore_mismatch
@staticmethod @staticmethod
def relaxed_value_for_var(value, var): def relaxed_value_for_var(value, var, ignore_mismatch=False):
""" """
Returns a relaxed (possibly reshaped/upcast-ed) version of value, Returns a relaxed (possibly reshaped/upcast-ed) version of value,
to be loaded to the given variable. to be loaded to the given variable.
...@@ -56,9 +59,13 @@ class SessionUpdate(object): ...@@ -56,9 +59,13 @@ class SessionUpdate(object):
Args: Args:
value (ndarray): an numpy array to be loaded to var value (ndarray): an numpy array to be loaded to var
var (tf.Variable): var (tf.Variable):
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
Returns: Returns:
ndarray: a possibly reshaped or casted version of value ndarray: a possibly reshaped or casted version of value.
Returns None if `ignore_mismatch==True` and the value and the variable
mismatch.
""" """
assert isinstance(var, tf.Variable) assert isinstance(var, tf.Variable)
name = var.op.name name = var.op.name
...@@ -66,11 +73,17 @@ class SessionUpdate(object): ...@@ -66,11 +73,17 @@ class SessionUpdate(object):
# check incompatible shape # check incompatible shape
varshape = tuple(var.get_shape().as_list()) varshape = tuple(var.get_shape().as_list())
if varshape != value.shape: if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
if np.prod(varshape) != np.prod(value.shape): if np.prod(varshape) != np.prod(value.shape):
if ignore_mismatch:
logger.warn(
"Cannot load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
value.shape, name, varshape))
return None
else:
raise ValueError( raise ValueError(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format( "Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
value.shape, name, varshape)) value.shape, name, varshape))
# TODO only allow reshape when shape different by empty axis
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format( logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
value.shape, varshape, name)) value.shape, varshape, name))
value = value.reshape(varshape) value = value.reshape(varshape)
...@@ -115,9 +128,12 @@ class SessionUpdate(object): ...@@ -115,9 +128,12 @@ class SessionUpdate(object):
for name, value in six.iteritems(prms): for name, value in six.iteritems(prms):
assert name in self.name_map assert name in self.name_map
var = self.name_map[name] var = self.name_map[name]
fetches.append(var.initializer) value = SessionUpdate.relaxed_value_for_var(
value, var, ignore_mismatch=self.ignore_mismatch)
# This is the implementation of `var.load` # This is the implementation of `var.load`
feeds[var.initializer.inputs[1]] = SessionUpdate.relaxed_value_for_var(value, var) if value is not None:
fetches.append(var.initializer)
feeds[var.initializer.inputs[1]] = value
self.sess.run(fetches, feed_dict=feeds) self.sess.run(fetches, feed_dict=feeds)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment