Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
683e43ff
Commit
683e43ff
authored
Nov 06, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix reference leak in call_only_once, use memoized_method for methods. (fix #969)
parent
69b68b26
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
61 additions
and
39 deletions
+61
-39
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+14
-14
examples/GAN/GAN.py
examples/GAN/GAN.py
+2
-2
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+3
-3
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+2
-2
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+40
-18
No files found.
examples/FasterRCNN/model_frcnn.py
View file @
683e43ff
...
...
@@ -8,7 +8,7 @@ from tensorpack.tfutils.argscope import argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
(
Conv2D
,
FullyConnected
,
layer_register
)
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.utils.argtools
import
memoized
_method
from
basemodel
import
GroupNorm
from
utils.box_ops
import
pairwise_iou
...
...
@@ -316,22 +316,22 @@ class BoxProposals(object):
if
k
!=
'self'
and
v
is
not
None
:
setattr
(
self
,
k
,
v
)
@
memoized
@
memoized
_method
def
fg_inds
(
self
):
""" Returns: #fg indices in [0, N-1] """
return
tf
.
reshape
(
tf
.
where
(
self
.
labels
>
0
),
[
-
1
],
name
=
'fg_inds'
)
@
memoized
@
memoized
_method
def
fg_boxes
(
self
):
""" Returns: #fg x4"""
return
tf
.
gather
(
self
.
boxes
,
self
.
fg_inds
(),
name
=
'fg_boxes'
)
@
memoized
@
memoized
_method
def
fg_labels
(
self
):
""" Returns: #fg"""
return
tf
.
gather
(
self
.
labels
,
self
.
fg_inds
(),
name
=
'fg_labels'
)
@
memoized
@
memoized
_method
def
matched_gt_boxes
(
self
):
""" Returns: #fg x 4"""
return
tf
.
gather
(
self
.
gt_boxes
,
self
.
fg_inds_wrt_gt
)
...
...
@@ -354,12 +354,12 @@ class FastRCNNHead(object):
setattr
(
self
,
k
,
v
)
self
.
_bbox_class_agnostic
=
int
(
box_logits
.
shape
[
1
])
==
1
@
memoized
@
memoized
_method
def
fg_box_logits
(
self
):
""" Returns: #fg x ? x 4 """
return
tf
.
gather
(
self
.
box_logits
,
self
.
proposals
.
fg_inds
(),
name
=
'fg_box_logits'
)
@
memoized
@
memoized
_method
def
losses
(
self
):
encoded_fg_gt_boxes
=
encode_bbox_target
(
self
.
proposals
.
matched_gt_boxes
(),
...
...
@@ -369,7 +369,7 @@ class FastRCNNHead(object):
encoded_fg_gt_boxes
,
self
.
fg_box_logits
()
)
@
memoized
@
memoized
_method
def
decoded_output_boxes
(
self
):
""" Returns: N x #class x 4 """
anchors
=
tf
.
tile
(
tf
.
expand_dims
(
self
.
proposals
.
boxes
,
1
),
...
...
@@ -380,17 +380,17 @@ class FastRCNNHead(object):
)
return
decoded_boxes
@
memoized
@
memoized
_method
def
decoded_output_boxes_for_true_label
(
self
):
""" Returns: Nx4 decoded boxes """
return
self
.
_decoded_output_boxes_for_label
(
self
.
proposals
.
labels
)
@
memoized
@
memoized
_method
def
decoded_output_boxes_for_predicted_label
(
self
):
""" Returns: Nx4 decoded boxes """
return
self
.
_decoded_output_boxes_for_label
(
self
.
predicted_labels
())
@
memoized
@
memoized
_method
def
decoded_output_boxes_for_label
(
self
,
labels
):
assert
not
self
.
_bbox_class_agnostic
indices
=
tf
.
stack
([
...
...
@@ -404,7 +404,7 @@ class FastRCNNHead(object):
)
return
decoded
@
memoized
@
memoized
_method
def
decoded_output_boxes_class_agnostic
(
self
):
""" Returns: Nx4 """
assert
self
.
_bbox_class_agnostic
...
...
@@ -415,12 +415,12 @@ class FastRCNNHead(object):
)
return
decoded
@
memoized
@
memoized
_method
def
output_scores
(
self
,
name
=
None
):
""" Returns: N x #class scores, summed to one for each box."""
return
tf
.
nn
.
softmax
(
self
.
label_logits
,
name
=
name
)
@
memoized
@
memoized
_method
def
predicted_labels
(
self
):
""" Returns: N ints """
return
tf
.
argmax
(
self
.
label_logits
,
axis
=
1
,
name
=
'predicted_labels'
)
examples/GAN/GAN.py
View file @
683e43ff
...
...
@@ -9,7 +9,7 @@ from tensorpack import (TowerTrainer, StagingInput,
from
tensorpack.tfutils.tower
import
TowerContext
,
TowerFuncWrapper
from
tensorpack.graph_builder
import
DataParallelBuilder
,
LeastLoadedDeviceSetter
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.utils.argtools
import
memoized
_method
from
tensorpack.utils.develop
import
deprecated
...
...
@@ -68,7 +68,7 @@ class GANModelDesc(ModelDescBase):
"""
pass
@
memoized
@
memoized
_method
def
get_optimizer
(
self
):
return
self
.
optimizer
()
...
...
tensorpack/graph_builder/model_desc.py
View file @
683e43ff
...
...
@@ -6,7 +6,7 @@ from collections import namedtuple
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
_method
from
..utils.develop
import
log_deprecated
from
..tfutils.tower
import
get_current_tower_context
from
..models.regularize
import
regularize_cost_from_collection
...
...
@@ -90,7 +90,7 @@ class ModelDescBase(object):
Base class for a model description.
"""
@
memoized
@
memoized
_method
def
get_inputs_desc
(
self
):
"""
Returns:
...
...
@@ -207,7 +207,7 @@ class ModelDesc(ModelDescBase):
def
_get_cost
(
self
,
*
args
):
return
self
.
cost
@
memoized
@
memoized
_method
def
get_optimizer
(
self
):
"""
Return the memoized optimizer returned by `optimizer()`.
...
...
tensorpack/input_source/input_source_base.py
View file @
683e43ff
...
...
@@ -8,7 +8,7 @@ from six.moves import zip
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..utils.argtools
import
memoized
,
call_only_once
from
..utils.argtools
import
memoized
_method
,
call_only_once
from
..callbacks.base
import
CallbackFactory
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
...
...
@@ -109,7 +109,7 @@ class InputSource(object):
"""
return
self
.
_setup_done
@
memoized
@
memoized
_method
def
get_callbacks
(
self
):
"""
An InputSource might need some extra maintenance during training,
...
...
tensorpack/utils/argtools.py
View file @
683e43ff
...
...
@@ -10,7 +10,7 @@ if six.PY2:
else
:
import
functools
__all__
=
[
'map_arg'
,
'memoized'
,
'graph_memoized'
,
'shape2d'
,
'shape4d'
,
__all__
=
[
'map_arg'
,
'memoized'
,
'
memoized_method'
,
'
graph_memoized'
,
'shape2d'
,
'shape4d'
,
'memoized_ignoreargs'
,
'log_once'
,
'call_only_once'
]
...
...
@@ -39,13 +39,17 @@ def map_arg(**maps):
memoized
=
functools
.
lru_cache
(
maxsize
=
None
)
""" Alias to :func:`functools.lru_cache` """
""" Alias to :func:`functools.lru_cache`
WARNING: memoization will keep keys and values alive!
"""
def
graph_memoized
(
func
):
"""
Like memoized, but keep one cache per default graph.
"""
# TODO it keeps the graph alive
import
tensorflow
as
tf
GRAPH_ARG_NAME
=
'__IMPOSSIBLE_NAME_FOR_YOU__'
...
...
@@ -81,16 +85,6 @@ def memoized_ignoreargs(func):
return
_MEMOIZED_NOARGS
[
func
]
return
wrapper
# _GLOBAL_MEMOIZED_CACHE = dict()
# def global_memoized(func):
# """ Make sure that the same `memoized` object is returned on different
# calls to global_memoized(func)
# """
# ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
# if ret is None:
# ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
# return ret
def
shape2d
(
a
):
"""
...
...
@@ -152,9 +146,6 @@ def log_once(message, func='info'):
getattr
(
logger
,
func
)(
message
)
_FUNC_CALLED
=
set
()
def
call_only_once
(
func
):
"""
Decorate a method or property of a class, so that this method can only
...
...
@@ -168,21 +159,52 @@ def call_only_once(func):
# fails if func is a property
assert
func
.
__name__
in
dir
(
self
),
"call_only_once can only be used on method or property!"
if
not
hasattr
(
self
,
'_CALL_ONLY_ONCE_CACHE'
):
cache
=
self
.
_CALL_ONLY_ONCE_CACHE
=
set
()
else
:
cache
=
self
.
_CALL_ONLY_ONCE_CACHE
cls
=
type
(
self
)
# cannot use ismethod(), because decorated method becomes a function
is_method
=
inspect
.
isfunction
(
getattr
(
cls
,
func
.
__name__
))
key
=
(
self
,
func
)
assert
key
not
in
_FUNC_CALLED
,
\
assert
func
not
in
cache
,
\
"{} {}.{} can only be called once per object!"
.
format
(
'Method'
if
is_method
else
'Property'
,
cls
.
__name__
,
func
.
__name__
)
_FUNC_CALLED
.
add
(
key
)
cache
.
add
(
func
)
return
func
(
*
args
,
**
kwargs
)
return
wrapper
def
memoized_method
(
func
):
"""
A decorator that performs memoization on methods. It stores the cache on the object instance itself.
"""
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
self
=
args
[
0
]
assert
func
.
__name__
in
dir
(
self
),
"memoized_method can only be used on method!"
if
not
hasattr
(
self
,
'_MEMOIZED_CACHE'
):
cache
=
self
.
_MEMOIZED_CACHE
=
{}
else
:
cache
=
self
.
_MEMOIZED_CACHE
key
=
args
[
1
:]
+
tuple
(
kwargs
)
print
(
key
)
ret
=
cache
.
get
(
key
,
None
)
if
ret
is
not
None
:
return
ret
value
=
func
(
*
args
,
**
kwargs
)
cache
[
key
]
=
value
return
value
return
wrapper
if
__name__
==
'__main__'
:
class
A
():
def
__init__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment