Commit 4f4794df authored by Yuxin Wu's avatar Yuxin Wu

use inspect.signature instead of getcallargs. also improve argscope with this

parent 9850edf5
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import inspect
from functools import wraps from functools import wraps
import six import six
import re import re
...@@ -14,7 +15,7 @@ from ..utils import logger ...@@ -14,7 +15,7 @@ from ..utils import logger
# make sure each layer is only logged once # make sure each layer is only logged once
_LAYER_LOGGED = set() _LAYER_LOGGED = set()
_LAYER_REGISTERED = {} _LAYER_REGISTRY = {}
__all__ = ['layer_register'] __all__ = ['layer_register']
...@@ -53,11 +54,11 @@ class VariableHolder(object): ...@@ -53,11 +54,11 @@ class VariableHolder(object):
def _register(name, func): def _register(name, func):
if name in _LAYER_REGISTERED: if name in _LAYER_REGISTRY:
raise ValueError("Layer named {} is already registered!".format(name)) raise ValueError("Layer named {} is already registered!".format(name))
if name in ['tf']: if name in ['tf']:
raise ValueError(logger.error("A layer cannot be named {}".format(name))) raise ValueError(logger.error("A layer cannot be named {}".format(name)))
_LAYER_REGISTERED[name] = func _LAYER_REGISTRY[name] = func
def get_registered_layer(name): def get_registered_layer(name):
...@@ -67,7 +68,7 @@ def get_registered_layer(name): ...@@ -67,7 +68,7 @@ def get_registered_layer(name):
Returns: Returns:
the wrapped layer function, or None if not registered. the wrapped layer function, or None if not registered.
""" """
return _LAYER_REGISTERED.get(name, None) return _LAYER_REGISTRY.get(name, None)
def disable_layer_logging(): def disable_layer_logging():
...@@ -124,10 +125,18 @@ def layer_register( ...@@ -124,10 +125,18 @@ def layer_register(
isinstance(inputs[0], (tf.Tensor, tf.Variable)))): isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
raise ValueError("Invalid inputs to layer: " + str(inputs)) raise ValueError("Invalid inputs to layer: " + str(inputs))
# TODO use inspect.getcallargs to enhance? # use kwargs from current argument scope
# update from current argument scope
actual_args = copy.copy(get_arg_scope()[func.__name__]) actual_args = copy.copy(get_arg_scope()[func.__name__])
# explicit kwargs overwrite argscope
actual_args.update(kwargs) actual_args.update(kwargs)
# explicit positional args also override argscope
if six.PY2:
posargmap = inspect.getcallargs(func, *args)
else:
posargmap = inspect.signature(func).bind_partial(*args).arguments
for k in six.iterkeys(posargmap):
if k in actual_args:
del actual_args[k]
if name is not None: # use scope if name is not None: # use scope
with tf.variable_scope(name) as scope: with tf.variable_scope(name) as scope:
......
...@@ -20,12 +20,17 @@ def map_arg(**maps): ...@@ -20,12 +20,17 @@ def map_arg(**maps):
Apply a mapping on certains argument before calling the original function. Apply a mapping on certains argument before calling the original function.
Args: Args:
maps (dict): {key: map_func} maps (dict): {argument_name: map_func}
""" """
def deco(func): def deco(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs) if six.PY2:
argmap = inspect.getcallargs(func, *args, **kwargs)
else:
# getcallargs was deprecated since 3.5
sig = inspect.signature(func)
argmap = sig.bind_partial(*args, **kwargs).arguments
for k, map_func in six.iteritems(maps): for k, map_func in six.iteritems(maps):
if k in argmap: if k in argmap:
argmap[k] = map_func(argmap[k]) argmap[k] = map_func(argmap[k])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment