Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
78595e71
Commit
78595e71
authored
Feb 18, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
significantly improves speed of DictRestore
Loading a large COCO model takes 50 sec -> 0.4 sec
parent
5a868442
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
43 additions
and
28 deletions
+43
-28
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+3
-0
examples/ImageNetModels/vgg16.py
examples/ImageNetModels/vgg16.py
+1
-0
tensorpack/callbacks/hooks.py
tensorpack/callbacks/hooks.py
+7
-3
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-1
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+27
-23
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+3
-1
No files found.
examples/FasterRCNN/basemodel.py
View file @
78595e71
...
@@ -15,6 +15,9 @@ from config import config as cfg
...
@@ -15,6 +15,9 @@ from config import config as cfg
@
layer_register
(
log_shape
=
True
)
@
layer_register
(
log_shape
=
True
)
def
GroupNorm
(
x
,
group
=
32
,
gamma_initializer
=
tf
.
constant_initializer
(
1.
)):
def
GroupNorm
(
x
,
group
=
32
,
gamma_initializer
=
tf
.
constant_initializer
(
1.
)):
"""
More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
"""
shape
=
x
.
get_shape
()
.
as_list
()
shape
=
x
.
get_shape
()
.
as_list
()
ndims
=
len
(
shape
)
ndims
=
len
(
shape
)
assert
ndims
==
4
,
shape
assert
ndims
==
4
,
shape
...
...
examples/ImageNetModels/vgg16.py
View file @
78595e71
...
@@ -17,6 +17,7 @@ from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataf
...
@@ -17,6 +17,7 @@ from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataf
def
GroupNorm
(
x
,
group
,
gamma_initializer
=
tf
.
constant_initializer
(
1.
)):
def
GroupNorm
(
x
,
group
,
gamma_initializer
=
tf
.
constant_initializer
(
1.
)):
"""
"""
https://arxiv.org/abs/1803.08494
https://arxiv.org/abs/1803.08494
More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
"""
"""
shape
=
x
.
get_shape
()
.
as_list
()
shape
=
x
.
get_shape
()
.
as_list
()
ndims
=
len
(
shape
)
ndims
=
len
(
shape
)
...
...
tensorpack/callbacks/hooks.py
View file @
78595e71
...
@@ -13,9 +13,13 @@ __all__ = ['CallbackToHook', 'HookToCallback']
...
@@ -13,9 +13,13 @@ __all__ = ['CallbackToHook', 'HookToCallback']
class
CallbackToHook
(
tfv1
.
train
.
SessionRunHook
):
class
CallbackToHook
(
tfv1
.
train
.
SessionRunHook
):
""" This is only for internal implementation of
"""
before_run/after_run callbacks.
Hooks are less powerful than callbacks so the conversion is incomplete.
You shouldn't need to use this.
It only converts the `before_run/after_run` calls.
This is only for internal implementation of
before_run/after_run callbacks.
You shouldn't need to use this.
"""
"""
def
__init__
(
self
,
cb
):
def
__init__
(
self
,
cb
):
...
...
tensorpack/tfutils/sessinit.py
View file @
78595e71
...
@@ -174,7 +174,8 @@ class SaverRestoreRelaxed(SaverRestore):
...
@@ -174,7 +174,8 @@ class SaverRestoreRelaxed(SaverRestore):
def
f
(
reader
,
name
,
v
):
def
f
(
reader
,
name
,
v
):
val
=
reader
.
get_tensor
(
name
)
val
=
reader
.
get_tensor
(
name
)
SessionUpdate
.
load_value_to_var
(
v
,
val
)
v
.
load
(
SessionUpdate
.
relaxed_value_for_var
(
val
,
v
))
with
sess
.
as_default
():
with
sess
.
as_default
():
self
.
_match_vars
(
f
)
self
.
_match_vars
(
f
)
...
...
tensorpack/tfutils/varmanip.py
View file @
78595e71
...
@@ -47,30 +47,32 @@ class SessionUpdate(object):
...
@@ -47,30 +47,32 @@ class SessionUpdate(object):
self
.
name_map
=
{
v
.
name
:
v
for
v
in
vars_to_update
}
self
.
name_map
=
{
v
.
name
:
v
for
v
in
vars_to_update
}
@
staticmethod
@
staticmethod
def
load_value_to_var
(
var
,
val
,
strict
=
False
):
def
relaxed_value_for_var
(
value
,
var
):
"""
"""
Call `var.load(val)` with the default session, with some type checks.
Returns a relaxed (possibly reshaped/upcast-ed) version of value,
to be loaded to the given variable.
Args:
Args:
value (ndarray): an numpy array to be loaded to var
var (tf.Variable):
var (tf.Variable):
strict (bool): Behave less strict if set to False.
Returns:
ndarray: a possibly reshaped or casted version of value
"""
"""
if
strict
:
assert
isinstance
(
var
,
tf
.
Variable
)
var
.
load
(
val
)
return
name
=
var
.
op
.
name
name
=
var
.
op
.
name
# check incompatible shape
# check incompatible shape
varshape
=
tuple
(
var
.
get_shape
()
.
as_list
())
varshape
=
tuple
(
var
.
get_shape
()
.
as_list
())
if
varshape
!=
val
.
shape
:
if
varshape
!=
val
ue
.
shape
:
# TODO only allow reshape when shape different by empty axis
# TODO only allow reshape when shape different by empty axis
if
np
.
prod
(
varshape
)
!=
np
.
prod
(
val
.
shape
):
if
np
.
prod
(
varshape
)
!=
np
.
prod
(
val
ue
.
shape
):
raise
ValueError
(
raise
ValueError
(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}."
.
format
(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}."
.
format
(
val
.
shape
,
name
,
varshape
))
val
ue
.
shape
,
name
,
varshape
))
logger
.
warn
(
"The tensor is reshaped from {} to {} when assigned to '{}'"
.
format
(
logger
.
warn
(
"The tensor is reshaped from {} to {} when assigned to '{}'"
.
format
(
val
.
shape
,
varshape
,
name
))
val
ue
.
shape
,
varshape
,
name
))
val
=
val
.
reshape
(
varshape
)
val
ue
=
value
.
reshape
(
varshape
)
# fix some common type incompatibility problems, but not all
# fix some common type incompatibility problems, but not all
def
upcast
(
vartype
,
valtype
):
def
upcast
(
vartype
,
valtype
):
...
@@ -81,20 +83,17 @@ class SessionUpdate(object):
...
@@ -81,20 +83,17 @@ class SessionUpdate(object):
return
np
.
int64
if
vartype
==
tf
.
int64
else
np
.
int32
return
np
.
int64
if
vartype
==
tf
.
int64
else
np
.
int32
return
None
return
None
if
hasattr
(
val
,
'dtype'
):
if
hasattr
(
val
ue
,
'dtype'
):
vartype
=
var
.
value
()
.
dtype
vartype
=
var
.
value
()
.
dtype
if
vartype
!=
val
.
dtype
:
if
vartype
!=
val
ue
.
dtype
:
msg
=
"Variable {} has dtype {} but was given a value of dtype {}."
.
format
(
name
,
vartype
,
val
.
dtype
)
msg
=
"Variable {} has dtype {} but was given a value of dtype {}."
.
format
(
name
,
vartype
,
val
ue
.
dtype
)
newtype
=
upcast
(
var
.
dtype
.
base_dtype
,
val
.
dtype
)
newtype
=
upcast
(
var
.
dtype
.
base_dtype
,
val
ue
.
dtype
)
if
newtype
is
not
None
:
if
newtype
is
not
None
:
val
=
newtype
(
val
)
val
ue
=
newtype
(
value
)
logger
.
warn
(
msg
+
" Load it after casting!"
)
logger
.
warn
(
msg
+
" Load it after casting!"
)
else
:
else
:
assert
vartype
==
val
.
dtype
,
msg
assert
vartype
==
value
.
dtype
,
msg
try
:
return
value
var
.
load
(
val
)
except
tf
.
errors
.
InvalidArgumentError
:
logger
.
exc
(
"Cannot load this value to the variable {}"
.
format
(
name
))
def
update
(
self
,
prms
):
def
update
(
self
,
prms
):
"""
"""
...
@@ -103,10 +102,15 @@ class SessionUpdate(object):
...
@@ -103,10 +102,15 @@ class SessionUpdate(object):
Any name in prms must be in the graph and in vars_to_update.
Any name in prms must be in the graph and in vars_to_update.
"""
"""
with
self
.
sess
.
as_default
():
with
self
.
sess
.
as_default
():
fetches
=
[]
feeds
=
{}
for
name
,
value
in
six
.
iteritems
(
prms
):
for
name
,
value
in
six
.
iteritems
(
prms
):
assert
name
in
self
.
name_map
assert
name
in
self
.
name_map
v
=
self
.
name_map
[
name
]
var
=
self
.
name_map
[
name
]
SessionUpdate
.
load_value_to_var
(
v
,
value
)
fetches
.
append
(
var
.
initializer
)
# This is the implementation of `var.load`
feeds
[
var
.
initializer
.
inputs
[
1
]]
=
SessionUpdate
.
relaxed_value_for_var
(
value
,
var
)
self
.
sess
.
run
(
fetches
,
feed_dict
=
feeds
)
def
dump_session_params
(
path
):
def
dump_session_params
(
path
):
...
...
tensorpack/utils/timer.py
View file @
78595e71
...
@@ -40,11 +40,13 @@ def timed_operation(msg, log_start=False):
...
@@ -40,11 +40,13 @@ def timed_operation(msg, log_start=False):
Good stuff finished, time:1sec.
Good stuff finished, time:1sec.
"""
"""
assert
len
(
msg
)
if
log_start
:
if
log_start
:
logger
.
info
(
'Start {} ...'
.
format
(
msg
))
logger
.
info
(
'Start {} ...'
.
format
(
msg
))
start
=
timer
()
start
=
timer
()
yield
yield
logger
.
info
(
'{} finished, time:{:.4f}sec.'
.
format
(
msg
=
msg
[
0
]
.
upper
()
+
msg
[
1
:]
logger
.
info
(
'{} finished, time:{:.4f} sec.'
.
format
(
msg
,
timer
()
-
start
))
msg
,
timer
()
-
start
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment