Commit 5567ebe1 authored by Yuxin Wu

Use tf.layers.Dense to implement FC (#291)

parent 9505edc6
......@@ -6,7 +6,6 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from tensorflow.python.layers.normalization import BatchNormalization as TF_BatchNorm
from ..utils import logger
from ..tfutils.tower import get_current_tower_context
......@@ -177,6 +176,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
Batch Renormalization layer, as described in the paper:
`Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
<https://arxiv.org/abs/1702.03275>`_.
This implementation is a wrapper around `tf.layers.batch_normalization`.
Args:
x (tf.Tensor): a NHWC or NC tensor.
......@@ -210,7 +210,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
ctx = get_current_tower_context()
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
layer = TF_BatchNorm(
layer = tf.layers.BatchNormalization(
axis=1 if data_format == 'NCHW' else 3,
momentum=decay, epsilon=epsilon,
center=use_bias, scale=use_scale,
......
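A minimal usage sketch of the BatchRenorm layer above (the input shape and clipping values are made up for illustration, and the call assumes the layer takes a scope name as its first argument, as tensorpack layers do):

image = tf.placeholder(tf.float32, [None, 32, 32, 64], name='feature')
out = BatchRenorm('brn', image, rmax=3.0, dmax=5.0)
# under the hood this constructs tf.layers.BatchNormalization, presumably with
# renorm enabled and the renorm clipping driven by rmax/dmax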
# -*- coding: UTF-8 -*-
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import inspect
from functools import wraps
import six
import re
import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str
from ..utils import logger
# make sure each layer is only logged once
_LAYER_LOGGED = set()
_LAYER_REGISTRY = {}
__all__ = ['layer_register']
class VariableHolder(object):
""" A proxy to access variables defined in a layer. """
def __init__(self, **kwargs):
"""
Args:
kwargs: {name:variable}
"""
self._vars = {}
for k, v in six.iteritems(kwargs):
self._add_variable(k, v)
def _add_variable(self, name, var):
assert name not in self._vars
self._vars[name] = var
def __setattr__(self, name, var):
if not name.startswith('_'):
self._add_variable(name, var)
else:
# private attributes
super(VariableHolder, self).__setattr__(name, var)
def __getattr__(self, name):
return self._vars[name]
def all(self):
"""
Returns:
list of all variables
"""
return list(six.itervalues(self._vars))
def _register(name, func):
if name in _LAYER_REGISTRY:
raise ValueError("Layer named {} is already registered!".format(name))
if name in ['tf']:
raise ValueError(logger.error("A layer cannot be named {}".format(name)))
_LAYER_REGISTRY[name] = func
def get_registered_layer(name):
"""
Args:
name (str): the name of the layer, e.g. 'Conv2D'
Returns:
the wrapped layer function, or None if not registered.
"""
return _LAYER_REGISTRY.get(name, None)
def disable_layer_logging():
"""
Disable the shape logging for all layers from this moment on. Can be
useful when creating multiple towers.
"""
class ContainEverything:
def __contains__(self, x):
return True
# can use nonlocal in python3, but how
globals()['_LAYER_LOGGED'] = ContainEverything()
def layer_register(
log_shape=False,
use_scope=True):
"""
Args:
log_shape (bool): log input/output shape of this layer
use_scope (bool or None):
Whether to call this layer with an extra first argument as the scope name.
When set to None, the layer can be called either with or without
the scope name argument; it figures out which form is used
by checking whether the first argument is a string.
Returns:
A decorator used to register a layer.
Examples:
.. code-block:: python
@layer_register(use_scope=True)
def add10(x):
return x + tf.get_variable('W', shape=[10])
"""
def wrapper(func):
@wraps(func)
def wrapped_func(*args, **kwargs):
assert args[0] is not None, args
if use_scope:
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
assert isinstance(name, six.string_types), name
else:
assert not log_shape
if isinstance(args[0], six.string_types):
if use_scope is False:
logger.warn(
"Please call layer {} without the first scope name argument, "
"or register the layer with use_scope=None to allow "
"two calling methods.".format(func.__name__))
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
else:
inputs = args[0]
name = None
if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
(isinstance(inputs, (list, tuple)) and
isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
raise ValueError("Invalid inputs to layer: " + str(inputs))
# use kwargs from current argument scope
actual_args = copy.copy(get_arg_scope()[func.__name__])
# explicit kwargs overwrite argscope
actual_args.update(kwargs)
if six.PY3:
# explicit positional args also override argscope. only works in PY3
posargmap = inspect.signature(func).bind_partial(*args).arguments
for k in six.iterkeys(posargmap):
if k in actual_args:
del actual_args[k]
if name is not None: # use scope
with tf.variable_scope(name) as scope:
# this name is only used to suppress logging; it doesn't hurt to do some heuristics
scope_name = re.sub('tower[0-9]+/', '', scope.name)
do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
if do_log_shape:
logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
# run the actual function
outputs = func(*args, **actual_args)
if do_log_shape:
# log shape info and add activation
logger.info("{} output: {}".format(
scope.name, get_shape_str(outputs)))
_LAYER_LOGGED.add(scope_name)
else:
# run the actual function
outputs = func(*args, **actual_args)
return outputs
wrapped_func.symbolic_function = func # attribute to access the underlying function object
wrapped_func.use_scope = use_scope
_register(func.__name__, wrapped_func)
return wrapped_func
return wrapper
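To make the decorator's contract concrete, here is a hypothetical layer built with it (ScaleShift and its parameter are invented for illustration and are not part of this commit):

@layer_register(log_shape=True)
def ScaleShift(x, gamma_init=1.0):
    # a trivial layer: one learnable scalar multiplying the input
    gamma = tf.get_variable('gamma', [], initializer=tf.constant_initializer(gamma_init))
    return tf.identity(x * gamma, name='output')

# use_scope defaults to True, so the first argument is the variable scope name
y = ScaleShift('scale1', x)
y = ScaleShift('scale2', y, gamma_init=0.5)   # explicit kwargs override any argscope defaults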
from .registry import layer_register # noqa
from .utils import VariableHolder, rename_get_variable # noqa
......@@ -5,7 +5,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from .common import layer_register, rename_get_variable, VariableHolder
from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']
......@@ -16,7 +16,8 @@ def FullyConnected(x, out_dim,
W_init=None, b_init=None,
nl=tf.identity, use_bias=True):
"""
Fully-Connected layer. Takes a N>1D tensor and returns a 2D tensor.
Fully-Connected layer, takes an N>1D tensor and returns a 2D tensor.
It is equivalent to `tf.layers.dense` except for naming conventions.
Args:
x (tf.Tensor): a tensor to be flattened except for the first dimension.
......@@ -35,21 +36,20 @@ def FullyConnected(x, out_dim,
* ``b``: bias
"""
x = symbf.batch_flatten(x)
in_dim = x.get_shape().as_list()[1]
if W_init is None:
W_init = tf.contrib.layers.variance_scaling_initializer()
if b_init is None:
b_init = tf.constant_initializer()
W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init)
if use_bias:
b = tf.get_variable('b', [out_dim], initializer=b_init)
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
layer = tf.layers.Dense(
out_dim, activation=lambda x: nl(x, name='output'), use_bias=use_bias,
kernel_initializer=W_init, bias_initializer=b_init,
trainable=True)
ret = layer.apply(x, scope=tf.get_variable_scope())
ret = nl(prod, name='output')
ret.variables = VariableHolder(W=W)
ret.variables = VariableHolder(W=layer.kernel)
if use_bias:
ret.variables.b = b
ret.variables.b = layer.bias
return ret
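A short usage sketch of the rewritten layer (the input tensor is hypothetical; the variable names follow from the rename_get_variable mapping above):

x = tf.placeholder(tf.float32, [None, 28, 28, 3], name='input')
fc0 = FullyConnected('fc0', x, 512, nl=tf.nn.relu)
# the underlying tf.layers.Dense kernel/bias are created as fc0/W and fc0/b,
# and remain reachable through the VariableHolder:
print(fc0.variables.W.name)   # expected to be 'fc0/W:0'
print(fc0.variables.b.name)   # expected to be 'fc0/b:0'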
......@@ -5,7 +5,7 @@
import six
from types import ModuleType
from .common import get_registered_layer
from .registry import get_registered_layer
__all__ = ['LinearWrap']
......
# -*- coding: UTF-8 -*-
# File: registry.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import inspect
from functools import wraps
import six
import re
import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str
from ..utils import logger
# make sure each layer is only logged once
_LAYER_LOGGED = set()
_LAYER_REGISTRY = {}
__all__ = ['layer_register']
def _register(name, func):
if name in _LAYER_REGISTRY:
raise ValueError("Layer named {} is already registered!".format(name))
if name in ['tf']:
raise ValueError(logger.error("A layer cannot be named {}".format(name)))
_LAYER_REGISTRY[name] = func
def get_registered_layer(name):
"""
Args:
name (str): the name of the layer, e.g. 'Conv2D'
Returns:
the wrapped layer function, or None if not registered.
"""
return _LAYER_REGISTRY.get(name, None)
def disable_layer_logging():
"""
Disable the shape logging for all layers from this moment on. Can be
useful when creating multiple towers.
"""
class ContainEverything:
def __contains__(self, x):
return True
# can use nonlocal in python3, but how
globals()['_LAYER_LOGGED'] = ContainEverything()
def layer_register(
log_shape=False,
use_scope=True):
"""
Args:
log_shape (bool): log input/output shape of this layer
use_scope (bool or None):
Whether to call this layer with an extra first argument as the scope name.
When set to None, the layer can be called either with or without
the scope name argument; it figures out which form is used
by checking whether the first argument is a string.
Returns:
A decorator used to register a layer.
Examples:
.. code-block:: python
@layer_register(use_scope=True)
def add10(x):
return x + tf.get_variable('W', shape=[10])
"""
def wrapper(func):
@wraps(func)
def wrapped_func(*args, **kwargs):
assert args[0] is not None, args
if use_scope:
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
assert isinstance(name, six.string_types), name
else:
assert not log_shape
if isinstance(args[0], six.string_types):
if use_scope is False:
logger.warn(
"Please call layer {} without the first scope name argument, "
"or register the layer with use_scope=None to allow "
"two calling methods.".format(func.__name__))
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
else:
inputs = args[0]
name = None
if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
(isinstance(inputs, (list, tuple)) and
isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
raise ValueError("Invalid inputs to layer: " + str(inputs))
# use kwargs from current argument scope
actual_args = copy.copy(get_arg_scope()[func.__name__])
# explicit kwargs overwrite argscope
actual_args.update(kwargs)
if six.PY3:
# explicit positional args also override argscope. only works in PY3
posargmap = inspect.signature(func).bind_partial(*args).arguments
for k in six.iterkeys(posargmap):
if k in actual_args:
del actual_args[k]
if name is not None: # use scope
with tf.variable_scope(name) as scope:
# this name is only used to suppress logging; it doesn't hurt to do some heuristics
scope_name = re.sub('tower[0-9]+/', '', scope.name)
do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
if do_log_shape:
logger.info("{} input: {}".format(scope.name, get_shape_str(inputs)))
# run the actual function
outputs = func(*args, **actual_args)
if do_log_shape:
# log shape info and add activation
logger.info("{} output: {}".format(
scope.name, get_shape_str(outputs)))
_LAYER_LOGGED.add(scope_name)
else:
# run the actual function
outputs = func(*args, **actual_args)
return outputs
wrapped_func.symbolic_function = func # attribute to access the underlying function object
wrapped_func.use_scope = use_scope
_register(func.__name__, wrapped_func)
return wrapped_func
return wrapper
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.common import get_tf_version_number
import six
class VariableHolder(object):
""" A proxy to access variables defined in a layer. """
def __init__(self, **kwargs):
"""
Args:
kwargs: {name:variable}
"""
self._vars = {}
for k, v in six.iteritems(kwargs):
self._add_variable(k, v)
def _add_variable(self, name, var):
assert name not in self._vars
self._vars[name] = var
def __setattr__(self, name, var):
if not name.startswith('_'):
self._add_variable(name, var)
else:
# private attributes
super(VariableHolder, self).__setattr__(name, var)
def __getattr__(self, name):
return self._vars[name]
def all(self):
"""
Returns:
list of all variables
"""
return list(six.itervalues(self._vars))
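A tiny sketch of how the proxy behaves (the variables here are hypothetical):

holder = VariableHolder(W=tf.get_variable('W', [3, 4]))
holder.b = tf.get_variable('b', [4])   # plain attribute assignment also registers a variable
print(holder.W, holder.b)              # attribute access returns the stored variables
print(holder.all())                    # list of all registered variables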
def rename_get_variable(mapping):
"""
Args:
mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}
"""
def custom_getter(getter, name, *args, **kwargs):
splits = name.split('/')
basename = splits[-1]
if basename in mapping:
basename = mapping[basename]
splits[-1] = basename
name = '/'.join(splits)
return getter(name, *args, **kwargs)
return custom_getter_scope(custom_getter)
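A minimal sketch of the custom getter in action, following the same pattern the FullyConnected rewrite uses (the scope name and shapes are illustrative):

with tf.variable_scope('fc0'):
    with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
        layer = tf.layers.Dense(10)
        out = layer.apply(tf.zeros([4, 8]), scope=tf.get_variable_scope())
# Dense asks for variables with basenames 'kernel' and 'bias', but the custom
# getter rewrites them, so the graph ends up with 'fc0/W' and 'fc0/b'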
def monkeypatch_tf_layers():
if get_tf_version_number() < 1.4:
if not hasattr(tf.layers, 'Dense'):
from tensorflow.python.layers.core import Dense
tf.layers.Dense = Dense
from tensorflow.python.layers.normalization import BatchNormalization
tf.layers.BatchNormalization = BatchNormalization
monkeypatch_tf_layers()
......@@ -42,6 +42,12 @@ def remap_variables(fn):
Returns:
a context where all the variables will be mapped by fn.
Example:
.. code-block:: python
with varreplace.remap_variables(lambda var: quantize(var)):
x = FullyConnected('fc', x, 1000) # fc/{W,b} will be quantized
"""
def custom_getter(getter, *args, **kwargs):
v = getter(*args, **kwargs)
......
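As a concrete stand-in for the `quantize` function referenced in the docstring example above, a sign-binarization mapping could look like this (binarize is hypothetical and not part of the commit):

def binarize(v):
    # straight-through estimator: the forward pass sees sign(v), gradients flow to v
    return v + tf.stop_gradient(tf.sign(v) - v)

with remap_variables(binarize):
    x = FullyConnected('fc', x, 1000)   # fc/W and fc/b are binarized in the forward pass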