Commit 931640c5 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

enable_arg_scope for a given function (#1035)

parent 4eaeed3f
......@@ -10,7 +10,8 @@ from inspect import getmembers, isfunction
from ..utils import logger
from .tower import get_current_tower_context
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module']
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module',
'enable_argscope_for_function']
_ArgScopeStack = []
......@@ -67,8 +68,21 @@ def get_arg_scope():
return defaultdict(dict)
def argscope_mapper(func, log_shape=True):
def enable_argscope_for_function(func, log_shape=True):
"""Decorator for function to support argscope
Example:
.. code-block:: python
from mylib import myfunc
myfunc = enable_argscope_for_function(myfunc)
Args:
func: function which should be decorated.
log_shape (bool): print input/output shapes of each function.
Returns:
The decorated function.
"""
@wraps(func)
def wrapped_func(*args, **kwargs):
......@@ -82,7 +96,8 @@ def argscope_mapper(func, log_shape=True):
if log_shape:
if ('tower' not in ctx.ns_name.lower()) or ctx.is_main_training_tower:
logger.info('%20s: %20s -> %20s' %
(name, in_tensor.shape.as_list(), out_tensor.shape.as_list()))
(name, in_tensor.shape.as_list(),
out_tensor.shape.as_list()))
return out_tensor
# argscope requires this property
......@@ -93,12 +108,19 @@ def argscope_mapper(func, log_shape=True):
def enable_argscope_for_module(module, log_shape=True):
"""
Overwrite all functions of a given module to support argscope.
Note that this function monkey-patches the module and therefore could have unexpected consequences.
Note that this function monkey-patches the module and therefore could
have unexpected consequences.
It has been only tested to work well with `tf.layers` module.
Example:
.. code-block:: python
import tensorflow as tf
enable_argscope_for_module(tf.layers)
Args:
log_shape (bool): print input/output shapes of each function when called.
log_shape (bool): print input/output shapes of each function.
"""
for name, obj in getmembers(module):
if isfunction(obj):
setattr(module, name, argscope_mapper(obj, log_shape=log_shape))
setattr(module, name, enable_argscope_for_function(obj,
log_shape=log_shape))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment