Commit 4f4794df authored by Yuxin Wu's avatar Yuxin Wu

use inspect.signature instead of getcallargs. also improve argscope with this

parent 9850edf5
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import inspect
from functools import wraps
import six
import re
......@@ -14,7 +15,7 @@ from ..utils import logger
# make sure each layer is only logged once
_LAYER_LOGGED = set()
_LAYER_REGISTERED = {}
_LAYER_REGISTRY = {}
__all__ = ['layer_register']
......@@ -53,11 +54,11 @@ class VariableHolder(object):
def _register(name, func):
if name in _LAYER_REGISTERED:
if name in _LAYER_REGISTRY:
raise ValueError("Layer named {} is already registered!".format(name))
if name in ['tf']:
raise ValueError(logger.error("A layer cannot be named {}".format(name)))
_LAYER_REGISTERED[name] = func
_LAYER_REGISTRY[name] = func
def get_registered_layer(name):
......@@ -67,7 +68,7 @@ def get_registered_layer(name):
Returns:
the wrapped layer function, or None if not registered.
"""
return _LAYER_REGISTERED.get(name, None)
return _LAYER_REGISTRY.get(name, None)
def disable_layer_logging():
......@@ -124,10 +125,18 @@ def layer_register(
isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
raise ValueError("Invalid inputs to layer: " + str(inputs))
# TODO use inspect.getcallargs to enhance?
# update from current argument scope
# use kwargs from current argument scope
actual_args = copy.copy(get_arg_scope()[func.__name__])
# explicit kwargs overwrite argscope
actual_args.update(kwargs)
# explicit positional args also override argscope
if six.PY2:
posargmap = inspect.getcallargs(func, *args)
else:
posargmap = inspect.signature(func).bind_partial(*args).arguments
for k in six.iterkeys(posargmap):
if k in actual_args:
del actual_args[k]
if name is not None: # use scope
with tf.variable_scope(name) as scope:
......
......@@ -20,12 +20,17 @@ def map_arg(**maps):
Apply a mapping on certains argument before calling the original function.
Args:
maps (dict): {key: map_func}
maps (dict): {argument_name: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
if six.PY2:
argmap = inspect.getcallargs(func, *args, **kwargs)
else:
# getcallargs was deprecated since 3.5
sig = inspect.signature(func)
argmap = sig.bind_partial(*args, **kwargs).arguments
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment