Commit 683e43ff authored by Yuxin Wu's avatar Yuxin Wu

Fix reference leak in call_only_once, use memoized_method for methods. (fix #969)

parent 69b68b26
......@@ -8,7 +8,7 @@ from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, FullyConnected, layer_register)
from tensorpack.utils.argtools import memoized
from tensorpack.utils.argtools import memoized_method
from basemodel import GroupNorm
from utils.box_ops import pairwise_iou
......@@ -316,22 +316,22 @@ class BoxProposals(object):
if k != 'self' and v is not None:
setattr(self, k, v)
@memoized
@memoized_method
def fg_inds(self):
""" Returns: #fg indices in [0, N-1] """
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds')
@memoized
@memoized_method
def fg_boxes(self):
""" Returns: #fg x4"""
return tf.gather(self.boxes, self.fg_inds(), name='fg_boxes')
@memoized
@memoized_method
def fg_labels(self):
""" Returns: #fg"""
return tf.gather(self.labels, self.fg_inds(), name='fg_labels')
@memoized
@memoized_method
def matched_gt_boxes(self):
""" Returns: #fg x 4"""
return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt)
......@@ -354,12 +354,12 @@ class FastRCNNHead(object):
setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
@memoized
@memoized_method
def fg_box_logits(self):
""" Returns: #fg x ? x 4 """
return tf.gather(self.box_logits, self.proposals.fg_inds(), name='fg_box_logits')
@memoized
@memoized_method
def losses(self):
encoded_fg_gt_boxes = encode_bbox_target(
self.proposals.matched_gt_boxes(),
......@@ -369,7 +369,7 @@ class FastRCNNHead(object):
encoded_fg_gt_boxes, self.fg_box_logits()
)
@memoized
@memoized_method
def decoded_output_boxes(self):
""" Returns: N x #class x 4 """
anchors = tf.tile(tf.expand_dims(self.proposals.boxes, 1),
......@@ -380,17 +380,17 @@ class FastRCNNHead(object):
)
return decoded_boxes
@memoized
@memoized_method
def decoded_output_boxes_for_true_label(self):
""" Returns: Nx4 decoded boxes """
return self._decoded_output_boxes_for_label(self.proposals.labels)
@memoized
@memoized_method
def decoded_output_boxes_for_predicted_label(self):
""" Returns: Nx4 decoded boxes """
return self._decoded_output_boxes_for_label(self.predicted_labels())
@memoized
@memoized_method
def decoded_output_boxes_for_label(self, labels):
assert not self._bbox_class_agnostic
indices = tf.stack([
......@@ -404,7 +404,7 @@ class FastRCNNHead(object):
)
return decoded
@memoized
@memoized_method
def decoded_output_boxes_class_agnostic(self):
""" Returns: Nx4 """
assert self._bbox_class_agnostic
......@@ -415,12 +415,12 @@ class FastRCNNHead(object):
)
return decoded
@memoized
@memoized_method
def output_scores(self, name=None):
""" Returns: N x #class scores, summed to one for each box."""
return tf.nn.softmax(self.label_logits, name=name)
@memoized
@memoized_method
def predicted_labels(self):
""" Returns: N ints """
return tf.argmax(self.label_logits, axis=1, name='predicted_labels')
......@@ -9,7 +9,7 @@ from tensorpack import (TowerTrainer, StagingInput,
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized
from tensorpack.utils.argtools import memoized_method
from tensorpack.utils.develop import deprecated
......@@ -68,7 +68,7 @@ class GANModelDesc(ModelDescBase):
"""
pass
@memoized
@memoized_method
def get_optimizer(self):
return self.optimizer()
......
......@@ -6,7 +6,7 @@ from collections import namedtuple
import tensorflow as tf
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated
from ..tfutils.tower import get_current_tower_context
from ..models.regularize import regularize_cost_from_collection
......@@ -90,7 +90,7 @@ class ModelDescBase(object):
Base class for a model description.
"""
@memoized
@memoized_method
def get_inputs_desc(self):
"""
Returns:
......@@ -207,7 +207,7 @@ class ModelDesc(ModelDescBase):
def _get_cost(self, *args):
return self.cost
@memoized
@memoized_method
def get_optimizer(self):
"""
Return the memoized optimizer returned by `optimizer()`.
......
......@@ -8,7 +8,7 @@ from six.moves import zip
from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import memoized, call_only_once
from ..utils.argtools import memoized_method, call_only_once
from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
......@@ -109,7 +109,7 @@ class InputSource(object):
"""
return self._setup_done
@memoized
@memoized_method
def get_callbacks(self):
"""
An InputSource might need some extra maintenance during training,
......
......@@ -10,7 +10,7 @@ if six.PY2:
else:
import functools
__all__ = ['map_arg', 'memoized', 'graph_memoized', 'shape2d', 'shape4d',
__all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once', 'call_only_once']
......@@ -39,13 +39,17 @@ def map_arg(**maps):
memoized = functools.lru_cache(maxsize=None)
""" Alias to :func:`functools.lru_cache` """
""" Alias to :func:`functools.lru_cache`
WARNING: memoization will keep keys and values alive!
"""
def graph_memoized(func):
"""
Like memoized, but keep one cache per default graph.
"""
# TODO it keeps the graph alive
import tensorflow as tf
GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
......@@ -81,16 +85,6 @@ def memoized_ignoreargs(func):
return _MEMOIZED_NOARGS[func]
return wrapper
# _GLOBAL_MEMOIZED_CACHE = dict()
# def global_memoized(func):
# """ Make sure that the same `memoized` object is returned on different
# calls to global_memoized(func)
# """
# ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
# if ret is None:
# ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
# return ret
def shape2d(a):
"""
......@@ -152,9 +146,6 @@ def log_once(message, func='info'):
getattr(logger, func)(message)
_FUNC_CALLED = set()
def call_only_once(func):
"""
Decorate a method or property of a class, so that this method can only
......@@ -168,21 +159,52 @@ def call_only_once(func):
# fails if func is a property
assert func.__name__ in dir(self), "call_only_once can only be used on method or property!"
if not hasattr(self, '_CALL_ONLY_ONCE_CACHE'):
cache = self._CALL_ONLY_ONCE_CACHE = set()
else:
cache = self._CALL_ONLY_ONCE_CACHE
cls = type(self)
# cannot use ismethod(), because decorated method becomes a function
is_method = inspect.isfunction(getattr(cls, func.__name__))
key = (self, func)
assert key not in _FUNC_CALLED, \
assert func not in cache, \
"{} {}.{} can only be called once per object!".format(
'Method' if is_method else 'Property',
cls.__name__, func.__name__)
_FUNC_CALLED.add(key)
cache.add(func)
return func(*args, **kwargs)
return wrapper
def memoized_method(func):
"""
A decorator that performs memoization on methods. It stores the cache on the object instance itself.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
self = args[0]
assert func.__name__ in dir(self), "memoized_method can only be used on method!"
if not hasattr(self, '_MEMOIZED_CACHE'):
cache = self._MEMOIZED_CACHE = {}
else:
cache = self._MEMOIZED_CACHE
key = args[1:] + tuple(kwargs)
print(key)
ret = cache.get(key, None)
if ret is not None:
return ret
value = func(*args, **kwargs)
cache[key] = value
return value
return wrapper
if __name__ == '__main__':
class A():
def __init__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment