Commit 22c9c4d0 authored by Yuxin Wu's avatar Yuxin Wu

rename tflayer variables

parent 7800cf1c
# Symbolic Layers # Symbolic Layers
While you can use other symbolic libraries, Tensorpack contains a small collection of common model primitives,
tensorpack also contains a small collection of common model primitives,
such as conv/deconv, fc, bn, pooling layers. such as conv/deconv, fc, bn, pooling layers.
These layers were written only because there were no alternatives when
tensorpack was first developed.
Nowadays, these implementation actually call `tf.layers` directly.
Today, you can just use `tf.layers` or any other symbolic libraries inside tensorpack.
Using the tensorpack implementations, you can also benefit from `argscope` and `LinearWrap` to Using the tensorpack implementations, you can also benefit from `argscope` and `LinearWrap` to
simplify the code. simplify the code.
Note that these layers were written because there were no other alternatives back at that time. Note that to keep backward compatibility of code and pre-trained models, tensorpack layers
Now, these layers actually call `tf.layers` directly. have some small differences with `tf.layers`, including variable names and default options.
You can just use `tf.layers` as long as it fits your need. Refer to the API document for details.
### argscope and LinearWrap ### argscope and LinearWrap
`argscope` gives you a context with default arguments. `argscope` gives you a context with default arguments.
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
from .registry import layer_register # noqa from .registry import layer_register # noqa
from .utils import VariableHolder # noqa from .utils import VariableHolder # noqa
from .tflayer import rename_tflayer_get_variable
__all__ = ['layer_register', 'VariableHolder'] __all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable']
...@@ -72,6 +72,9 @@ def rename_get_variable(mapping): ...@@ -72,6 +72,9 @@ def rename_get_variable(mapping):
""" """
Args: Args:
mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'} mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}
Returns:
A context where the variables are renamed.
""" """
def custom_getter(getter, name, *args, **kwargs): def custom_getter(getter, name, *args, **kwargs):
splits = name.split('/') splits = name.split('/')
...@@ -84,6 +87,30 @@ def rename_get_variable(mapping): ...@@ -84,6 +87,30 @@ def rename_get_variable(mapping):
return custom_getter_scope(custom_getter) return custom_getter_scope(custom_getter)
def rename_tflayer_get_variable():
"""
Rename all :func:`tf.get_variable` with rules that transforms tflayer style to tensorpack style.
Returns:
A context where the variables are renamed.
Examples:
.. code-block:: python
with rename_tflayer_get_variable():
x = tf.layer.conv2d(input, 3, 3, name='conv0')
# variables will be named 'conv0/W', 'conv0/b'
"""
mapping = {
'kernel': 'W',
'bias': 'b',
'moving_mean': 'mean/EMA',
'moving_variance': 'variance/EMA',
}
return rename_get_variable(mapping)
def monkeypatch_tf_layers(): def monkeypatch_tf_layers():
if get_tf_version_number() < 1.4: if get_tf_version_number() < 1.4:
if not hasattr(tf.layers, 'Dense'): if not hasattr(tf.layers, 'Dense'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment