Commit 6686212e authored by Yuxin Wu

actually register the layer globally. fix #122.

parent e8c6cbf8
...@@ -80,7 +80,8 @@ Basically, a layer is a symbolic function with the following rules:
By making a symbolic function a "layer", the following things will happen:
+ You will call the function with a scope argument, e.g. `Conv2D('conv0', x, 32, 3)`.
Everything happening in this function will be under the variable scope 'conv0'. You can register
the layer with `use_scope=False` to disable this feature.
+ Static shapes of input/output will be logged.
+ It will then work with `argscope` to easily define default arguments. `argscope` will work for all
the arguments except the input (see the sketch below).
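For illustration, here is a minimal sketch of how registered layers are called (assuming the usual `Conv2D`/`Maxout`/`argscope` signatures; the input tensor and layer names are made up for this example):

```python
import tensorflow as tf
from tensorpack import *  # Conv2D, Maxout, argscope, ...

image = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')

# A layer takes a scope name first; everything it creates lives under 'conv0'.
l = Conv2D('conv0', image, 32, 3)

# Layers registered with use_scope=False (e.g. Maxout) are called without a name.
l = Maxout(l, 2)

# argscope sets default arguments for the listed layers inside the block.
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
    l = Conv2D('conv1', l, 64)
    l = Conv2D('conv2', l, 64)
```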
...@@ -88,3 +89,10 @@ By making a symbolic function a "layer", the following things will happen:
Take a look at the [Inception example](../examples/Inception/inception-bn.py#L36) to see how a complicated model can be described with these primitives.
There are also a number of symbolic functions in the `tfutils.symbolic_functions` module.
There isn't a rule about which symbolic functions should be made layers -- they're quite
similar anyway. But in general I define the following kinds of symbolic functions as layers:
+ Functions which contain variables. A variable scope is almost always helpful for such functions.
+ Functions which are commonly referred to as "layers", such as pooling. This makes a model
definition more straightforward (see the sketch below).
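As a sketch of the two cases (`layer_register` is the decorator from `models/common.py`; the layer names `ScaleShift` and `Clip01` are hypothetical, invented for this example):

```python
import tensorflow as tf
from tensorpack.models.common import layer_register

@layer_register()                  # has a variable -> wants a scope: ScaleShift('ss', x)
def ScaleShift(x, init_scale=1.0):
    scale = tf.get_variable('scale', [],
                            initializer=tf.constant_initializer(init_scale))
    return x * scale

@layer_register(use_scope=False)   # no variables -> a name would just be noise: Clip01(x)
def Clip01(x):
    return tf.clip_by_value(x, 0.0, 1.0)
```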
...@@ -28,7 +28,6 @@ DEPTH = None
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputVar(tf.int32, [None], 'label')]
...
...@@ -3,13 +3,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
from types import ModuleType
import six
import os
import os.path
# this line is necessary for _TFModuleFunc to work
import tensorflow as tf  # noqa: F401
from ..utils import logger
__all__ = ['LinearWrap']
...@@ -27,103 +24,3 @@ for _, module_name, _ in iter_modules(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
_global_import(module_name)
class LinearWrap(object):
""" A simple wrapper to easily create "linear" graph,
consisting of layers / symbolic functions with only one input & output.
"""
class _TFModuleFunc(object):
def __init__(self, mod, tensor):
self._mod = mod
self._t = tensor
def __getattr__(self, name):
ret = getattr(self._mod, name)
if isinstance(ret, ModuleType):
return LinearWrap._TFModuleFunc(ret, self._t)
else:
# assume to be a tf function
def f(*args, **kwargs):
o = ret(self._t, *args, **kwargs)
return LinearWrap(o)
return f
def __init__(self, tensor):
"""
Args:
tensor (tf.Tensor): the tensor to wrap
"""
self._t = tensor
def __getattr__(self, layer_name):
layer = eval(layer_name)
if hasattr(layer, 'f'):
# this is a registered tensorpack layer
# parse arguments by tensorpack model convention
if layer.use_scope:
def f(name, *args, **kwargs):
ret = layer(name, self._t, *args, **kwargs)
return LinearWrap(ret)
else:
def f(*args, **kwargs):
if len(args) and isinstance(args[0], six.string_types):
name, args = args[0], args[1:]
ret = layer(name, self._t, *args, **kwargs)
else:
ret = layer(self._t, *args, **kwargs)
return LinearWrap(ret)
return f
else:
if layer_name != 'tf':
logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'!")
assert isinstance(layer, ModuleType)
return LinearWrap._TFModuleFunc(layer, self._t)
def apply(self, func, *args, **kwargs):
"""
Apply a function on the wrapped tensor.
Returns:
LinearWrap: ``LinearWrap(func(self.tensor(), *args, **kwargs))``.
"""
ret = func(self._t, *args, **kwargs)
return LinearWrap(ret)
def apply2(self, func, *args, **kwargs):
"""
Apply a function on the wrapped tensor. The tensor
will be the second argument of func.
Returns:
LinearWrap: ``LinearWrap(func(args[0], self.tensor(), *args[1:], **kwargs))``.
"""
ret = func(args[0], self._t, *(args[1:]), **kwargs)
return LinearWrap(ret)
def __call__(self):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t
def tensor(self):
"""
Equivalent to ``self.__call__()``.
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t
def print_tensor(self):
"""
Print the underlying tensor and return self. Can be useful to get the
name of tensors inside :class:`LinearWrap`.
:return: self
"""
print(self._t)
return self
...@@ -13,9 +13,28 @@ from ..tfutils.summary import add_activation_summary
from ..utils import logger, building_rtfd
# make sure each layer is only logged once
_LAYER_LOGGED = set()
_LAYER_REGISTERED = {}
__all__ = ['layer_register', 'disable_layer_logging', 'get_registered_layer']
def _register(name, func):
if name in _LAYER_REGISTERED:
raise ValueError("Layer named {} is already registered!".format(name))
if name in ['tf']:
logger.error("A layer cannot be named {}".format(name))
raise ValueError("A layer cannot be named {}".format(name))
_LAYER_REGISTERED[name] = func
def get_registered_layer(name):
"""
Args:
name (str): the name of the layer, e.g. 'Conv2D'
Returns:
the wrapped layer function, or None if not registered.
"""
return _LAYER_REGISTERED.get(name, None)
def disable_layer_logging():
...@@ -27,7 +46,7 @@ def disable_layer_logging():
def __contains__(self, x):
return True
# can use nonlocal in python3, but how
globals()['_LAYER_LOGGED'] = ContainEverything()
def layer_register(
...@@ -76,8 +95,8 @@ def layer_register(
if name is not None:
with tf.variable_scope(name) as scope:
do_log_shape = log_shape and scope.name not in _LAYER_LOGGED
do_summary = do_summary and scope.name not in _LAYER_LOGGED
if do_log_shape:
logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
...@@ -88,7 +107,7 @@ def layer_register(
# log shape info and add activation
logger.info("{} output: {}".format(
scope.name, get_shape_str(outputs)))
_LAYER_LOGGED.add(scope.name)
if do_summary:
if isinstance(outputs, list):
...@@ -101,8 +120,9 @@ def layer_register(
outputs = func(*args, **actual_args)
return outputs
wrapped_func.f = func  # attribute to access the underlying function object
wrapped_func.use_scope = use_scope
_register(func.__name__, wrapped_func)
return wrapped_func
# need some special handling for sphinx to work with the arguments
...
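With this change, every decorated layer is also recorded in the global registry and can be looked up by name. A small sketch of the lookup (illustrative only, not part of the commit):

```python
# importing anything under tensorpack.models runs the package __init__,
# which imports all layer modules and thereby registers them
from tensorpack.models.common import get_registered_layer

conv = get_registered_layer('Conv2D')
assert conv is not None            # registered by @layer_register()
assert conv.use_scope              # Conv2D expects a scope name as its first argument
assert get_registered_layer('NoSuchLayer') is None   # unknown names return None
```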
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: linearwrap.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six
from types import ModuleType
from ..utils import logger
from .common import get_registered_layer
class LinearWrap(object):
""" A simple wrapper to easily create "linear" graph,
consisting of layers / symbolic functions with only one input & output.
"""
class _TFModuleFunc(object):
def __init__(self, mod, tensor):
self._mod = mod
self._t = tensor
def __getattr__(self, name):
ret = getattr(self._mod, name)
if isinstance(ret, ModuleType):
return LinearWrap._TFModuleFunc(ret, self._t)
else:
# assume to be a tf function
def f(*args, **kwargs):
o = ret(self._t, *args, **kwargs)
return LinearWrap(o)
return f
def __init__(self, tensor):
"""
Args:
tensor (tf.Tensor): the tensor to wrap
"""
self._t = tensor
def __getattr__(self, layer_name):
layer = get_registered_layer(layer_name)
if layer is not None:
# this is a registered tensorpack layer
# parse arguments by tensorpack model convention
if layer.use_scope:
def f(name, *args, **kwargs):
ret = layer(name, self._t, *args, **kwargs)
return LinearWrap(ret)
else:
def f(*args, **kwargs):
if len(args) and isinstance(args[0], six.string_types):
name, args = args[0], args[1:]
ret = layer(name, self._t, *args, **kwargs)
else:
ret = layer(self._t, *args, **kwargs)
return LinearWrap(ret)
return f
else:
if layer_name != 'tf':
logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'!")
# `layer` is None here (the name is not in the registry), so fall back to
# the tensorflow module itself to keep `.tf.xxx(...)` chaining working
import tensorflow as tf
return LinearWrap._TFModuleFunc(tf, self._t)
def apply(self, func, *args, **kwargs):
"""
Apply a function on the wrapped tensor.
Returns:
LinearWrap: ``LinearWrap(func(self.tensor(), *args, **kwargs))``.
"""
ret = func(self._t, *args, **kwargs)
return LinearWrap(ret)
def apply2(self, func, *args, **kwargs):
"""
Apply a function on the wrapped tensor. The tensor
will be the second argument of func.
Returns:
LinearWrap: ``LinearWrap(func(args[0], self.tensor(), *args[1:], **kwargs))``.
"""
ret = func(args[0], self._t, *(args[1:]), **kwargs)
return LinearWrap(ret)
def __call__(self):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t
def tensor(self):
"""
Equivalent to ``self.__call__()``.
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t
def print_tensor(self):
"""
Print the underlying tensor and return self. Can be useful to get the
name of tensors inside :class:`LinearWrap`.
:return: self
"""
print(self._t)
return self
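`LinearWrap` now resolves layer names through the registry instead of `eval`. For reference, a sketch of the chaining style it supports (the network itself is made up; layer signatures follow the usual tensorpack convention):

```python
import tensorflow as tf
from tensorpack import *  # LinearWrap, Conv2D, MaxPooling, FullyConnected, ...

image = tf.placeholder(tf.float32, [None, 28, 28, 1], name='input')

logits = (LinearWrap(image)
          .Conv2D('conv0', 32, 5)        # scoped layers: the name comes first
          .MaxPooling('pool0', 2)
          .Conv2D('conv1', 64, 5)
          .MaxPooling('pool1', 2)
          .FullyConnected('fc0', 512, nl=tf.nn.relu)
          .FullyConnected('fc1', 10, nl=tf.identity)())  # () unwraps the tf.Tensor
```

Each attribute access goes through `get_registered_layer`, so anything registered with `layer_register` is chainable; `apply()` covers plain symbolic functions that were not registered.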
...@@ -11,7 +11,7 @@ from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
@layer_register(use_scope=False)
def Maxout(x, num_unit):
"""
Maxout as in the paper `Maxout Networks <http://arxiv.org/abs/1302.4389>`_.
...