Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ab229670
Commit
ab229670
authored
Aug 07, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Allow ignore_mismatch when loading a model to a session
parent
4d041d06
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
16 deletions
+46
-16
examples/FasterRCNN/dataset/coco.py
examples/FasterRCNN/dataset/coco.py
+1
-0
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+20
-7
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+25
-9
No files found.
examples/FasterRCNN/dataset/coco.py
View file @
ab229670
...
...
@@ -235,6 +235,7 @@ def register_coco(basedir):
DatasetRegistry
.
register
(
name
,
lambda
x
=
split
:
COCODetection
(
basedir
,
x
))
DatasetRegistry
.
register_metadata
(
name
,
'class_names'
,
class_names
)
if
__name__
==
'__main__'
:
basedir
=
'~/data/coco'
c
=
COCODetection
(
basedir
,
'train2014'
)
...
...
tensorpack/tfutils/sessinit.py
View file @
ab229670
...
...
@@ -174,7 +174,9 @@ class SaverRestoreRelaxed(SaverRestore):
def
f
(
reader
,
name
,
v
):
val
=
reader
.
get_tensor
(
name
)
v
.
load
(
SessionUpdate
.
relaxed_value_for_var
(
val
,
v
))
val
=
SessionUpdate
.
relaxed_value_for_var
(
val
,
v
,
ignore_mismatch
=
True
)
if
val
is
not
None
:
v
.
load
(
val
)
with
sess
.
as_default
():
self
.
_match_vars
(
f
)
...
...
@@ -185,14 +187,17 @@ class DictRestore(SessionInit):
Restore variables from a dictionary.
"""
def
__init__
(
self
,
variable_dict
):
def
__init__
(
self
,
variable_dict
,
ignore_mismatch
=
False
):
"""
Args:
variable_dict (dict): a dict of {name: value}
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
"""
assert
isinstance
(
variable_dict
,
dict
),
type
(
variable_dict
)
# use varname (with :0) for consistency
self
.
_prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
variable_dict
)}
self
.
_ignore_mismatch
=
ignore_mismatch
def
_run_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
...
...
@@ -218,7 +223,7 @@ class DictRestore(SessionInit):
mismatch
.
add
(
k
)
mismatch
.
log
()
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
]
,
ignore_mismatch
=
self
.
_ignore_mismatch
)
logger
.
info
(
"Restoring {} variables from dict ..."
.
format
(
len
(
intersect
)))
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
_prms
)
if
name
in
intersect
})
...
...
@@ -246,10 +251,15 @@ class ChainInit(SessionInit):
i
.
_run_init
(
sess
)
def
get_model_loader
(
filename
):
def
get_model_loader
(
filename
,
ignore_mismatch
=
False
):
"""
Get a corresponding model loader by looking at the file name.
Args:
filename (str): either a tensorflow checkpoint, or a npz file.
ignore_mismatch (bool): ignore failures when values in the file and
variables in the graph do not match.
Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
...
...
@@ -258,10 +268,13 @@ def get_model_loader(filename):
filename
=
os
.
path
.
expanduser
(
filename
)
if
filename
.
endswith
(
'.npy'
):
assert
tf
.
gfile
.
Exists
(
filename
),
filename
return
DictRestore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
return
DictRestore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
()
,
ignore_mismatch
=
ignore_mismatch
)
elif
filename
.
endswith
(
'.npz'
):
assert
tf
.
gfile
.
Exists
(
filename
),
filename
obj
=
np
.
load
(
filename
)
return
DictRestore
(
dict
(
obj
))
return
DictRestore
(
dict
(
obj
)
,
ignore_mismatch
=
ignore_mismatch
)
else
:
return
SaverRestore
(
filename
)
if
ignore_mismatch
:
return
SaverRestoreRelaxed
(
filename
)
else
:
return
SaverRestore
(
filename
)
tensorpack/tfutils/varmanip.py
View file @
ab229670
...
...
@@ -38,17 +38,20 @@ def get_savename_from_varname(
class
SessionUpdate
(
object
):
""" Update the variables in a session """
def
__init__
(
self
,
sess
,
vars_to_update
):
def
__init__
(
self
,
sess
,
vars_to_update
,
ignore_mismatch
=
False
):
"""
Args:
sess (tf.Session): a session object
vars_to_update: a collection of variables to update
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
"""
self
.
sess
=
sess
self
.
name_map
=
{
v
.
name
:
v
for
v
in
vars_to_update
}
self
.
ignore_mismatch
=
ignore_mismatch
@
staticmethod
def
relaxed_value_for_var
(
value
,
var
):
def
relaxed_value_for_var
(
value
,
var
,
ignore_mismatch
=
False
):
"""
Returns a relaxed (possibly reshaped/upcast-ed) version of value,
to be loaded to the given variable.
...
...
@@ -56,9 +59,13 @@ class SessionUpdate(object):
Args:
value (ndarray): an numpy array to be loaded to var
var (tf.Variable):
ignore_mismatch (bool): ignore failures when the value and the
variable does not match.
Returns:
ndarray: a possibly reshaped or casted version of value
ndarray: a possibly reshaped or casted version of value.
Returns None if `ignore_mismatch==True` and the value and the variable
mismatch.
"""
assert
isinstance
(
var
,
tf
.
Variable
)
name
=
var
.
op
.
name
...
...
@@ -66,11 +73,17 @@ class SessionUpdate(object):
# check incompatible shape
varshape
=
tuple
(
var
.
get_shape
()
.
as_list
())
if
varshape
!=
value
.
shape
:
# TODO only allow reshape when shape different by empty axis
if
np
.
prod
(
varshape
)
!=
np
.
prod
(
value
.
shape
):
raise
ValueError
(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}."
.
format
(
value
.
shape
,
name
,
varshape
))
if
ignore_mismatch
:
logger
.
warn
(
"Cannot load a tensor of shape {} into the variable '{}' whose shape is {}."
.
format
(
value
.
shape
,
name
,
varshape
))
return
None
else
:
raise
ValueError
(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}."
.
format
(
value
.
shape
,
name
,
varshape
))
# TODO only allow reshape when shape different by empty axis
logger
.
warn
(
"The tensor is reshaped from {} to {} when assigned to '{}'"
.
format
(
value
.
shape
,
varshape
,
name
))
value
=
value
.
reshape
(
varshape
)
...
...
@@ -115,9 +128,12 @@ class SessionUpdate(object):
for
name
,
value
in
six
.
iteritems
(
prms
):
assert
name
in
self
.
name_map
var
=
self
.
name_map
[
name
]
fetches
.
append
(
var
.
initializer
)
value
=
SessionUpdate
.
relaxed_value_for_var
(
value
,
var
,
ignore_mismatch
=
self
.
ignore_mismatch
)
# This is the implementation of `var.load`
feeds
[
var
.
initializer
.
inputs
[
1
]]
=
SessionUpdate
.
relaxed_value_for_var
(
value
,
var
)
if
value
is
not
None
:
fetches
.
append
(
var
.
initializer
)
feeds
[
var
.
initializer
.
inputs
[
1
]]
=
value
self
.
sess
.
run
(
fetches
,
feed_dict
=
feeds
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment