Commit 900a7eb0 authored by Yuxin Wu's avatar Yuxin Wu

add graph_memoized and fix #276

parent 9b710110
......@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from tensorpack.utils.argtools import memoized
from tensorpack.utils.argtools import graph_memoized
@memoized
@graph_memoized
def get_dorefa(bitW, bitA, bitG):
"""
return the three quantization functions fw, fa, fg, for weights, activations and gradients respectively
......
......@@ -6,14 +6,14 @@ import tensorflow as tf
import re
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.argtools import graph_memoized
from ..tfutils.tower import get_current_tower_context
from .common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
@memoized
@graph_memoized
def _log_regularizer(name):
logger.info("Apply regularizer for {}".format(name))
......
......@@ -5,17 +5,14 @@
import tensorflow as tf
from six.moves import map
from ..utils.argtools import graph_memoized
from ..utils.naming import (
GLOBAL_STEP_VAR_NAME,
GLOBAL_STEP_OP_NAME)
from ..utils.naming import GLOBAL_STEP_OP_NAME
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
#'get_local_step_var',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
......@@ -51,15 +48,13 @@ def get_default_sess_config(mem_fraction=0.99):
return conf
@graph_memoized
def get_global_step_var():
"""
Returns:
tf.Tensor: the global_step variable in the current graph. create if
doesn't exist.
"""
try:
return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
except KeyError:
scope = tf.get_variable_scope()
assert scope.name == '', \
"The global_step variable should be created under the root variable scope!"
......
......@@ -38,6 +38,26 @@ memoized = functools.lru_cache(maxsize=None)
""" Alias to :func:`functools.lru_cache` """
def graph_memoized(func):
"""
Like memoized, but keep one cache per default graph.
"""
import tensorflow as tf
GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
@memoized
def func_with_graph_arg(*args, **kwargs):
kwargs.pop(GRAPH_ARG_NAME)
return func(*args, **kwargs)
def wrapper(*args, **kwargs):
assert GRAPH_ARG_NAME not in kwargs, "No Way!!"
graph = tf.get_default_graph()
kwargs[GRAPH_ARG_NAME] = graph
return func_with_graph_arg(*args, **kwargs)
return wrapper
_MEMOIZED_NOARGS = {}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment