Commit 683e43ff authored by Yuxin Wu's avatar Yuxin Wu

Fix reference leak in call_only_once, use memoized_method for methods. (fix #969)

parent 69b68b26
...@@ -8,7 +8,7 @@ from tensorpack.tfutils.argscope import argscope ...@@ -8,7 +8,7 @@ from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import ( from tensorpack.models import (
Conv2D, FullyConnected, layer_register) Conv2D, FullyConnected, layer_register)
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized_method
from basemodel import GroupNorm from basemodel import GroupNorm
from utils.box_ops import pairwise_iou from utils.box_ops import pairwise_iou
...@@ -316,22 +316,22 @@ class BoxProposals(object): ...@@ -316,22 +316,22 @@ class BoxProposals(object):
if k != 'self' and v is not None: if k != 'self' and v is not None:
setattr(self, k, v) setattr(self, k, v)
@memoized @memoized_method
def fg_inds(self): def fg_inds(self):
""" Returns: #fg indices in [0, N-1] """ """ Returns: #fg indices in [0, N-1] """
return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds') return tf.reshape(tf.where(self.labels > 0), [-1], name='fg_inds')
@memoized @memoized_method
def fg_boxes(self): def fg_boxes(self):
""" Returns: #fg x4""" """ Returns: #fg x4"""
return tf.gather(self.boxes, self.fg_inds(), name='fg_boxes') return tf.gather(self.boxes, self.fg_inds(), name='fg_boxes')
@memoized @memoized_method
def fg_labels(self): def fg_labels(self):
""" Returns: #fg""" """ Returns: #fg"""
return tf.gather(self.labels, self.fg_inds(), name='fg_labels') return tf.gather(self.labels, self.fg_inds(), name='fg_labels')
@memoized @memoized_method
def matched_gt_boxes(self): def matched_gt_boxes(self):
""" Returns: #fg x 4""" """ Returns: #fg x 4"""
return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt) return tf.gather(self.gt_boxes, self.fg_inds_wrt_gt)
...@@ -354,12 +354,12 @@ class FastRCNNHead(object): ...@@ -354,12 +354,12 @@ class FastRCNNHead(object):
setattr(self, k, v) setattr(self, k, v)
self._bbox_class_agnostic = int(box_logits.shape[1]) == 1 self._bbox_class_agnostic = int(box_logits.shape[1]) == 1
@memoized @memoized_method
def fg_box_logits(self): def fg_box_logits(self):
""" Returns: #fg x ? x 4 """ """ Returns: #fg x ? x 4 """
return tf.gather(self.box_logits, self.proposals.fg_inds(), name='fg_box_logits') return tf.gather(self.box_logits, self.proposals.fg_inds(), name='fg_box_logits')
@memoized @memoized_method
def losses(self): def losses(self):
encoded_fg_gt_boxes = encode_bbox_target( encoded_fg_gt_boxes = encode_bbox_target(
self.proposals.matched_gt_boxes(), self.proposals.matched_gt_boxes(),
...@@ -369,7 +369,7 @@ class FastRCNNHead(object): ...@@ -369,7 +369,7 @@ class FastRCNNHead(object):
encoded_fg_gt_boxes, self.fg_box_logits() encoded_fg_gt_boxes, self.fg_box_logits()
) )
@memoized @memoized_method
def decoded_output_boxes(self): def decoded_output_boxes(self):
""" Returns: N x #class x 4 """ """ Returns: N x #class x 4 """
anchors = tf.tile(tf.expand_dims(self.proposals.boxes, 1), anchors = tf.tile(tf.expand_dims(self.proposals.boxes, 1),
...@@ -380,17 +380,17 @@ class FastRCNNHead(object): ...@@ -380,17 +380,17 @@ class FastRCNNHead(object):
) )
return decoded_boxes return decoded_boxes
@memoized @memoized_method
def decoded_output_boxes_for_true_label(self): def decoded_output_boxes_for_true_label(self):
""" Returns: Nx4 decoded boxes """ """ Returns: Nx4 decoded boxes """
return self._decoded_output_boxes_for_label(self.proposals.labels) return self._decoded_output_boxes_for_label(self.proposals.labels)
@memoized @memoized_method
def decoded_output_boxes_for_predicted_label(self): def decoded_output_boxes_for_predicted_label(self):
""" Returns: Nx4 decoded boxes """ """ Returns: Nx4 decoded boxes """
return self._decoded_output_boxes_for_label(self.predicted_labels()) return self._decoded_output_boxes_for_label(self.predicted_labels())
@memoized @memoized_method
def decoded_output_boxes_for_label(self, labels): def decoded_output_boxes_for_label(self, labels):
assert not self._bbox_class_agnostic assert not self._bbox_class_agnostic
indices = tf.stack([ indices = tf.stack([
...@@ -404,7 +404,7 @@ class FastRCNNHead(object): ...@@ -404,7 +404,7 @@ class FastRCNNHead(object):
) )
return decoded return decoded
@memoized @memoized_method
def decoded_output_boxes_class_agnostic(self): def decoded_output_boxes_class_agnostic(self):
""" Returns: Nx4 """ """ Returns: Nx4 """
assert self._bbox_class_agnostic assert self._bbox_class_agnostic
...@@ -415,12 +415,12 @@ class FastRCNNHead(object): ...@@ -415,12 +415,12 @@ class FastRCNNHead(object):
) )
return decoded return decoded
@memoized @memoized_method
def output_scores(self, name=None): def output_scores(self, name=None):
""" Returns: N x #class scores, summed to one for each box.""" """ Returns: N x #class scores, summed to one for each box."""
return tf.nn.softmax(self.label_logits, name=name) return tf.nn.softmax(self.label_logits, name=name)
@memoized @memoized_method
def predicted_labels(self): def predicted_labels(self):
""" Returns: N ints """ """ Returns: N ints """
return tf.argmax(self.label_logits, axis=1, name='predicted_labels') return tf.argmax(self.label_logits, axis=1, name='predicted_labels')
...@@ -9,7 +9,7 @@ from tensorpack import (TowerTrainer, StagingInput, ...@@ -9,7 +9,7 @@ from tensorpack import (TowerTrainer, StagingInput,
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized_method
from tensorpack.utils.develop import deprecated from tensorpack.utils.develop import deprecated
...@@ -68,7 +68,7 @@ class GANModelDesc(ModelDescBase): ...@@ -68,7 +68,7 @@ class GANModelDesc(ModelDescBase):
""" """
pass pass
@memoized @memoized_method
def get_optimizer(self): def get_optimizer(self):
return self.optimizer() return self.optimizer()
......
...@@ -6,7 +6,7 @@ from collections import namedtuple ...@@ -6,7 +6,7 @@ from collections import namedtuple
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
...@@ -90,7 +90,7 @@ class ModelDescBase(object): ...@@ -90,7 +90,7 @@ class ModelDescBase(object):
Base class for a model description. Base class for a model description.
""" """
@memoized @memoized_method
def get_inputs_desc(self): def get_inputs_desc(self):
""" """
Returns: Returns:
...@@ -207,7 +207,7 @@ class ModelDesc(ModelDescBase): ...@@ -207,7 +207,7 @@ class ModelDesc(ModelDescBase):
def _get_cost(self, *args): def _get_cost(self, *args):
return self.cost return self.cost
@memoized @memoized_method
def get_optimizer(self): def get_optimizer(self):
""" """
Return the memoized optimizer returned by `optimizer()`. Return the memoized optimizer returned by `optimizer()`.
......
...@@ -8,7 +8,7 @@ from six.moves import zip ...@@ -8,7 +8,7 @@ from six.moves import zip
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from ..utils.argtools import memoized, call_only_once from ..utils.argtools import memoized_method, call_only_once
from ..callbacks.base import CallbackFactory from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
...@@ -109,7 +109,7 @@ class InputSource(object): ...@@ -109,7 +109,7 @@ class InputSource(object):
""" """
return self._setup_done return self._setup_done
@memoized @memoized_method
def get_callbacks(self): def get_callbacks(self):
""" """
An InputSource might need some extra maintenance during training, An InputSource might need some extra maintenance during training,
......
...@@ -10,7 +10,7 @@ if six.PY2: ...@@ -10,7 +10,7 @@ if six.PY2:
else: else:
import functools import functools
__all__ = ['map_arg', 'memoized', 'graph_memoized', 'shape2d', 'shape4d', __all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once', 'call_only_once'] 'memoized_ignoreargs', 'log_once', 'call_only_once']
...@@ -39,13 +39,17 @@ def map_arg(**maps): ...@@ -39,13 +39,17 @@ def map_arg(**maps):
memoized = functools.lru_cache(maxsize=None) memoized = functools.lru_cache(maxsize=None)
""" Alias to :func:`functools.lru_cache` """ """ Alias to :func:`functools.lru_cache`
WARNING: memoization will keep keys and values alive!
"""
def graph_memoized(func): def graph_memoized(func):
""" """
Like memoized, but keep one cache per default graph. Like memoized, but keep one cache per default graph.
""" """
# TODO it keeps the graph alive
import tensorflow as tf import tensorflow as tf
GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__' GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
...@@ -81,16 +85,6 @@ def memoized_ignoreargs(func): ...@@ -81,16 +85,6 @@ def memoized_ignoreargs(func):
return _MEMOIZED_NOARGS[func] return _MEMOIZED_NOARGS[func]
return wrapper return wrapper
# _GLOBAL_MEMOIZED_CACHE = dict()
# def global_memoized(func):
# """ Make sure that the same `memoized` object is returned on different
# calls to global_memoized(func)
# """
# ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
# if ret is None:
# ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
# return ret
def shape2d(a): def shape2d(a):
""" """
...@@ -152,9 +146,6 @@ def log_once(message, func='info'): ...@@ -152,9 +146,6 @@ def log_once(message, func='info'):
getattr(logger, func)(message) getattr(logger, func)(message)
_FUNC_CALLED = set()
def call_only_once(func): def call_only_once(func):
""" """
Decorate a method or property of a class, so that this method can only Decorate a method or property of a class, so that this method can only
...@@ -168,21 +159,52 @@ def call_only_once(func): ...@@ -168,21 +159,52 @@ def call_only_once(func):
# fails if func is a property # fails if func is a property
assert func.__name__ in dir(self), "call_only_once can only be used on method or property!" assert func.__name__ in dir(self), "call_only_once can only be used on method or property!"
if not hasattr(self, '_CALL_ONLY_ONCE_CACHE'):
cache = self._CALL_ONLY_ONCE_CACHE = set()
else:
cache = self._CALL_ONLY_ONCE_CACHE
cls = type(self) cls = type(self)
# cannot use ismethod(), because decorated method becomes a function # cannot use ismethod(), because decorated method becomes a function
is_method = inspect.isfunction(getattr(cls, func.__name__)) is_method = inspect.isfunction(getattr(cls, func.__name__))
key = (self, func) assert func not in cache, \
assert key not in _FUNC_CALLED, \
"{} {}.{} can only be called once per object!".format( "{} {}.{} can only be called once per object!".format(
'Method' if is_method else 'Property', 'Method' if is_method else 'Property',
cls.__name__, func.__name__) cls.__name__, func.__name__)
_FUNC_CALLED.add(key) cache.add(func)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
def memoized_method(func):
    """
    A decorator that performs memoization on methods. It stores the cache on the object instance itself,
    so the cache is released together with the instance (unlike a module-level cache, which would keep
    every instance alive — see the reference-leak this commit fixes).

    Args:
        func: an unbound method. All arguments after ``self`` must be hashable.

    Returns:
        The wrapped method. Results are cached per-instance, keyed by the function
        and the call arguments.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]
        assert func.__name__ in dir(self), "memoized_method can only be used on method!"

        if not hasattr(self, '_MEMOIZED_CACHE'):
            cache = self._MEMOIZED_CACHE = {}
        else:
            cache = self._MEMOIZED_CACHE

        # Include `func` in the key: the cache dict is shared by all memoized
        # methods of the instance, so keying on arguments alone would make two
        # different methods with identical arguments collide.
        # Sort kwargs items (names AND values) so that keyword-argument values
        # participate in the key and the key is order-insensitive.
        key = (func,) + args[1:] + tuple(sorted(kwargs.items()))

        # Use a sentinel instead of `None` so that methods legitimately
        # returning None are still cached.
        ret = cache.get(key, _MEMOIZED_SENTINEL)
        if ret is not _MEMOIZED_SENTINEL:
            return ret
        value = func(*args, **kwargs)
        cache[key] = value
        return value

    return wrapper


# Module-level sentinel marking "no cached value" in memoized_method.
_MEMOIZED_SENTINEL = object()
if __name__ == '__main__': if __name__ == '__main__':
class A(): class A():
def __init__(self): def __init__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment