Commit 931640c5 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

enable_arg_scope for a given function (#1035)

parent 4eaeed3f
...@@ -10,7 +10,8 @@ from inspect import getmembers, isfunction ...@@ -10,7 +10,8 @@ from inspect import getmembers, isfunction
from ..utils import logger from ..utils import logger
from .tower import get_current_tower_context from .tower import get_current_tower_context
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module'] __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module',
'enable_argscope_for_function']
_ArgScopeStack = [] _ArgScopeStack = []
...@@ -67,8 +68,21 @@ def get_arg_scope(): ...@@ -67,8 +68,21 @@ def get_arg_scope():
return defaultdict(dict) return defaultdict(dict)
def argscope_mapper(func, log_shape=True): def enable_argscope_for_function(func, log_shape=True):
"""Decorator for function to support argscope """Decorator for function to support argscope
Example:
.. code-block:: python
from mylib import myfunc
myfunc = enable_argscope_for_function(myfunc)
Args:
func: function which should be decorated.
log_shape (bool): print input/output shapes of each function.
Returns:
The decorated function.
""" """
@wraps(func) @wraps(func)
def wrapped_func(*args, **kwargs): def wrapped_func(*args, **kwargs):
...@@ -82,7 +96,8 @@ def argscope_mapper(func, log_shape=True): ...@@ -82,7 +96,8 @@ def argscope_mapper(func, log_shape=True):
if log_shape: if log_shape:
if ('tower' not in ctx.ns_name.lower()) or ctx.is_main_training_tower: if ('tower' not in ctx.ns_name.lower()) or ctx.is_main_training_tower:
logger.info('%20s: %20s -> %20s' % logger.info('%20s: %20s -> %20s' %
(name, in_tensor.shape.as_list(), out_tensor.shape.as_list())) (name, in_tensor.shape.as_list(),
out_tensor.shape.as_list()))
return out_tensor return out_tensor
# argscope requires this property # argscope requires this property
...@@ -93,12 +108,19 @@ def argscope_mapper(func, log_shape=True): ...@@ -93,12 +108,19 @@ def argscope_mapper(func, log_shape=True):
def enable_argscope_for_module(module, log_shape=True): def enable_argscope_for_module(module, log_shape=True):
""" """
Overwrite all functions of a given module to support argscope. Overwrite all functions of a given module to support argscope.
Note that this function monkey-patches the module and therefore could have unexpected consequences. Note that this function monkey-patches the module and therefore could
have unexpected consequences.
It has been only tested to work well with `tf.layers` module. It has been only tested to work well with `tf.layers` module.
Example:
.. code-block:: python
import tensorflow as tf
enable_argscope_for_module(tf.layers)
Args: Args:
log_shape (bool): print input/output shapes of each function when called. log_shape (bool): print input/output shapes of each function.
""" """
for name, obj in getmembers(module): for name, obj in getmembers(module):
if isfunction(obj): if isfunction(obj):
setattr(module, name, argscope_mapper(obj, log_shape=log_shape)) setattr(module, name, enable_argscope_for_function(obj,
log_shape=log_shape))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment