Commit 22c9c4d0 authored by Yuxin Wu's avatar Yuxin Wu

rename tflayer variables

parent 7800cf1c
# Symbolic Layers
While you can use other symbolic libraries,
tensorpack also contains a small collection of common model primitives,
Tensorpack contains a small collection of common model primitives,
such as conv/deconv, fc, bn, pooling layers.
These layers were written only because there were no alternatives when
tensorpack was first developed.
Nowadays, these implementation actually call `tf.layers` directly.
Today, you can just use `tf.layers` or any other symbolic libraries inside tensorpack.
Using the tensorpack implementations, you can also benefit from `argscope` and `LinearWrap` to
simplify the code.
Note that these layers were written because there were no other alternatives back at that time.
Now, these layers actually call `tf.layers` directly.
You can just use `tf.layers` as long as it fits your need.
Note that to keep backward compatibility of code and pre-trained models, tensorpack layers
have some small differences with `tf.layers`, including variable names and default options.
Refer to the API document for details.
### argscope and LinearWrap
`argscope` gives you a context with default arguments.
......
......@@ -3,5 +3,6 @@
from .registry import layer_register # noqa
from .utils import VariableHolder # noqa
from .tflayer import rename_tflayer_get_variable
__all__ = ['layer_register', 'VariableHolder']
__all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable']
......@@ -72,6 +72,9 @@ def rename_get_variable(mapping):
"""
Args:
mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}
Returns:
A context where the variables are renamed.
"""
def custom_getter(getter, name, *args, **kwargs):
splits = name.split('/')
......@@ -84,6 +87,30 @@ def rename_get_variable(mapping):
return custom_getter_scope(custom_getter)
def rename_tflayer_get_variable():
"""
Rename all :func:`tf.get_variable` with rules that transforms tflayer style to tensorpack style.
Returns:
A context where the variables are renamed.
Examples:
.. code-block:: python
with rename_tflayer_get_variable():
x = tf.layer.conv2d(input, 3, 3, name='conv0')
# variables will be named 'conv0/W', 'conv0/b'
"""
mapping = {
'kernel': 'W',
'bias': 'b',
'moving_mean': 'mean/EMA',
'moving_variance': 'variance/EMA',
}
return rename_get_variable(mapping)
def monkeypatch_tf_layers():
if get_tf_version_number() < 1.4:
if not hasattr(tf.layers, 'Dense'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment