Commit 900a7eb0 authored by Yuxin Wu

add graph_memoized and fix #276

parent 9b710110
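Note on the change below: it replaces the process-wide memoized (an alias of functools.lru_cache) with a new graph_memoized everywhere the cached result is tied to a TensorFlow graph. With a plain @memoized, the first call's result is returned forever, so after the default graph is replaced, later callers receive tensors belonging to a dead graph. A minimal sketch of that failure mode and the fix, assuming TF 1.x graph-mode APIs as used in this diff (make_const is a hypothetical helper, not part of the commit):

import tensorflow as tf
from tensorpack.utils.argtools import memoized, graph_memoized

@memoized
def make_const():
    # The returned tensor belongs to whichever graph was active
    # on the very first call.
    return tf.constant(1, name='cached_const')

with tf.Graph().as_default() as g1:
    t1 = make_const()        # computed and cached; lives in g1
with tf.Graph().as_default() as g2:
    t2 = make_const()        # stale cache hit: still the g1 tensor
    assert t2.graph is g1    # using t2 inside g2 would fail

@graph_memoized
def make_const_per_graph():
    return tf.constant(1, name='cached_const')

with tf.Graph().as_default():
    a = make_const_per_graph()
    b = make_const_per_graph()
    assert a is b            # cached within one graph...
with tf.Graph().as_default():
    c = make_const_per_graph()
    assert c is not a        # ...but recomputed for a new graph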
@@ -4,10 +4,10 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf

-from tensorpack.utils.argtools import memoized
+from tensorpack.utils.argtools import graph_memoized


-@memoized
+@graph_memoized
 def get_dorefa(bitW, bitA, bitG):
     """
     return the three quantization functions fw, fa, fg, for weights, activations and gradients respectively
@@ -6,14 +6,14 @@ import tensorflow as tf
 import re

 from ..utils import logger
-from ..utils.argtools import memoized
+from ..utils.argtools import graph_memoized
 from ..tfutils.tower import get_current_tower_context
 from .common import layer_register

 __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']


-@memoized
+@graph_memoized
 def _log_regularizer(name):
     logger.info("Apply regularizer for {}".format(name))
@@ -5,17 +5,14 @@
 import tensorflow as tf
 from six.moves import map

-from ..utils.naming import (
-    GLOBAL_STEP_VAR_NAME,
-    GLOBAL_STEP_OP_NAME)
+from ..utils.argtools import graph_memoized
+from ..utils.naming import GLOBAL_STEP_OP_NAME

 __all__ = ['get_default_sess_config',
            'get_global_step_value',
            'get_global_step_var',
-           #'get_local_step_var',
            'get_op_tensor_name',
            'get_tensors_by_names',
            'get_op_or_tensor_by_name',
@@ -51,24 +48,22 @@ def get_default_sess_config(mem_fraction=0.99):
     return conf


+@graph_memoized
 def get_global_step_var():
     """
     Returns:
         tf.Tensor: the global_step variable in the current graph. create if
         doesn't exist.
     """
-    try:
-        return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    except KeyError:
-        scope = tf.get_variable_scope()
-        assert scope.name == '', \
-            "The global_step variable should be created under the root variable scope!"
-        with tf.variable_scope(scope, reuse=False), \
-                tf.name_scope(None):
-            var = tf.get_variable(GLOBAL_STEP_OP_NAME,
-                                  initializer=tf.constant(0, dtype=tf.int64),
-                                  trainable=False, dtype=tf.int64)
-        return var
+    scope = tf.get_variable_scope()
+    assert scope.name == '', \
+        "The global_step variable should be created under the root variable scope!"
+    with tf.variable_scope(scope, reuse=False), \
+            tf.name_scope(None):
+        var = tf.get_variable(GLOBAL_STEP_OP_NAME,
+                              initializer=tf.constant(0, dtype=tf.int64),
+                              trainable=False, dtype=tf.int64)
+    return var


 def get_global_step_value():
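With @graph_memoized doing the per-graph caching, get_global_step_var no longer needs the try/except lookup by GLOBAL_STEP_VAR_NAME: the first call in a graph creates the variable, and later calls in the same graph are cache hits returning the same object. A sketch of the intended behavior after this change, assuming the module path tensorpack.tfutils.common for the file shown above:

import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var

with tf.Graph().as_default():
    v1 = get_global_step_var()   # first call: variable is created
    v2 = get_global_step_var()   # cache hit: the same variable object
    assert v1 is v2
with tf.Graph().as_default():
    v3 = get_global_step_var()   # new graph: a fresh variable
    assert v3 is not v1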
@@ -38,6 +38,26 @@ memoized = functools.lru_cache(maxsize=None)
 """ Alias to :func:`functools.lru_cache` """


+def graph_memoized(func):
+    """
+    Like memoized, but keep one cache per default graph.
+    """
+    import tensorflow as tf
+    GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__'
+
+    @memoized
+    def func_with_graph_arg(*args, **kwargs):
+        kwargs.pop(GRAPH_ARG_NAME)
+        return func(*args, **kwargs)
+
+    def wrapper(*args, **kwargs):
+        assert GRAPH_ARG_NAME not in kwargs, "No Way!!"
+        graph = tf.get_default_graph()
+        kwargs[GRAPH_ARG_NAME] = graph
+        return func_with_graph_arg(*args, **kwargs)
+    return wrapper


 _MEMOIZED_NOARGS = {}
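The mechanism, as the code above shows: the outer wrapper smuggles the current default graph into each call as an extra keyword argument under a name no caller would plausibly use, so functools.lru_cache keys the cache on the graph as well as the real arguments; the inner memoized function pops that argument back out before invoking func. A small usage sketch (build_op is a hypothetical function, not part of the commit):

import tensorflow as tf
from tensorpack.utils.argtools import graph_memoized

@graph_memoized
def build_op(n):
    print("building for n={}".format(n))
    return tf.constant(n)

with tf.Graph().as_default():
    build_op(3)    # prints "building for n=3"
    build_op(3)    # same (graph, n) cache key: no print
with tf.Graph().as_default():
    build_op(3)    # new graph changes the key: prints again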