Commit ea60a630 authored by Yuxin Wu

Translate data_format and activation to tflayers (#627)

parent 0e5299bb
......@@ -8,6 +8,7 @@ from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection
......@@ -67,7 +68,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
@layer_register()
def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True,
gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
gamma_init=tf.constant_initializer(1.0),
data_format='channels_last',
internal_update=False):
"""
Batch Normalization layer, as described in the paper:
......@@ -109,6 +111,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
don't want to fine tune the EMA. EMA will not be updated in
this case.
"""
data_format = get_data_format(data_format, tfmode=False)
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
......@@ -181,7 +184,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
@layer_register()
def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True, gamma_init=None, data_format='NHWC'):
use_scale=True, use_bias=True, gamma_init=None,
data_format='channels_last'):
"""
Batch Renormalization layer, as described in the paper:
`Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
......@@ -210,18 +214,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
ndims = len(shape)
assert ndims in [2, 4]
if ndims == 2:
data_format = 'NHWC' # error using NCHW? (see #190)
data_format = 'channels_last' # error using NCHW? (see #190)
x = tf.reshape(x, [-1, 1, 1, shape[1]])
if data_format == 'NCHW':
n_out = shape[1]
else:
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
ctx = get_current_tower_context()
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
layer = tf.layers.BatchNormalization(
axis=1 if data_format == 'NCHW' else 3,
axis=1 if data_format == 'channels_first' else 3,
momentum=decay, epsilon=epsilon,
center=use_bias, scale=use_scale,
renorm=True,
......
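For context, an illustrative sketch (not part of the commit, and assuming the usual top-level tensorpack imports): because get_data_format(..., tfmode=False) translates the tf.layers spellings and passes already-valid 'NHWC'/'NCHW' strings through unchanged, both spellings remain accepted by BatchNorm.
# Hypothetical call sites; tensor and layer names are made up for illustration.
import tensorflow as tf
from tensorpack import BatchNorm, Conv2D, GlobalAvgPooling
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y1 = BatchNorm('bn_legacy', x, data_format='NHWC')             # old spelling, still accepted
y2 = BatchNorm('bn_tflayers', x, data_format='channels_last')  # new tf.layers spelling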
......@@ -6,7 +6,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder, rename_get_variable
from ..tfutils.common import get_tf_version_number
from ..utils.argtools import shape2d, shape4d
from ..utils.argtools import shape2d, shape4d, get_data_format
__all__ = ['Conv2D', 'Deconv2D']
......@@ -15,8 +15,8 @@ __all__ = ['Conv2D', 'Deconv2D']
def Conv2D(x, out_channel, kernel_shape,
padding='SAME', stride=1,
W_init=None, b_init=None,
nl=tf.identity, split=1, use_bias=True,
data_format='NHWC', dilation_rate=1):
activation=tf.identity, split=1, use_bias=True,
data_format='channels_last', dilation_rate=1):
"""
2D convolution on 4D inputs.
......@@ -30,9 +30,7 @@ def Conv2D(x, out_channel, kernel_shape,
split (int): Split channels as used in Alexnet. Defaults to 1 (no split).
W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
b_init: initializer for b. Defaults to zero.
nl: a nonlinearity function.
use_bias (bool): whether to use bias.
data_format (str): 'NHWC' or 'NCHW'.
dilation_rate: (h, w) tuple or an int.
Returns:
......@@ -43,6 +41,7 @@ def Conv2D(x, out_channel, kernel_shape,
* ``W``: weights
* ``b``: bias
"""
data_format = get_data_format(data_format, tfmode=False)
in_shape = x.get_shape().as_list()
channel_axis = 3 if data_format == 'NHWC' else 1
in_channel = in_shape[channel_axis]
......@@ -79,7 +78,7 @@ def Conv2D(x, out_channel, kernel_shape,
for i, k in zip(inputs, kernels)]
conv = tf.concat(outputs, channel_axis)
ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
ret.variables = VariableHolder(W=W)
if use_bias:
ret.variables.b = b
......@@ -90,8 +89,8 @@ def Conv2D(x, out_channel, kernel_shape,
def Deconv2D(x, out_channel, kernel_shape,
stride, padding='SAME',
W_init=None, b_init=None,
nl=tf.identity, use_bias=True,
data_format='NHWC'):
activation=tf.identity, use_bias=True,
data_format='channels_last'):
"""
2D deconvolution on 4D inputs.
......@@ -104,7 +103,6 @@ def Deconv2D(x, out_channel, kernel_shape,
padding (str): 'valid' or 'same'. Case insensitive.
W_init: initializer for W. Defaults to `tf.variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
b_init: initializer for b. Defaults to zero.
nl: a nonlinearity function.
use_bias (bool): whether to use bias.
Returns:
......@@ -115,13 +113,6 @@ def Deconv2D(x, out_channel, kernel_shape,
* ``W``: weights
* ``b``: bias
"""
in_shape = x.get_shape().as_list()
channel_axis = 3 if data_format == 'NHWC' else 1
in_channel = in_shape[channel_axis]
assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
assert isinstance(out_channel, int), out_channel
if W_init is None:
W_init = tf.variance_scaling_initializer(scale=2.0)
if b_init is None:
......@@ -131,8 +122,8 @@ def Deconv2D(x, out_channel, kernel_shape,
layer = tf.layers.Conv2DTranspose(
out_channel, kernel_shape,
strides=stride, padding=padding,
data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
activation=lambda x: nl(x, name='output'),
data_format=data_format,
activation=activation,
use_bias=use_bias,
kernel_initializer=W_init,
bias_initializer=b_init,
......@@ -142,4 +133,4 @@ def Deconv2D(x, out_channel, kernel_shape,
ret.variables = VariableHolder(W=layer.kernel)
if use_bias:
ret.variables.b = layer.bias
return ret
return tf.identity(ret, name='output')
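Note on the hunk above: since activation is now handed straight to tf.layers (instead of a lambda that named its result 'output'), the explicit tf.identity(ret, name='output') preserves the output tensor name. Continuing the illustrative sketch from the BatchNorm example above, old nl= call sites should also keep working because layer_register routes keyword arguments through map_tfargs (shown further below):
# Hypothetical call sites; both should build equivalent graphs.
l_new = Conv2D('conv_new', x, 64, 3, activation=tf.nn.relu)
l_old = Conv2D('conv_old', x, 64, 3, nl=tf.nn.relu)  # 'nl' is translated to 'activation'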
......@@ -14,7 +14,7 @@ __all__ = ['FullyConnected']
@layer_register(log_shape=True)
def FullyConnected(x, out_dim,
W_init=None, b_init=None,
nl=tf.identity, use_bias=True):
activation=tf.identity, use_bias=True):
"""
Fully-Connected layer, takes an N>1D tensor and returns a 2D tensor.
It is an equivalent of `tf.layers.dense` except for naming conventions.
......@@ -44,7 +44,7 @@ def FullyConnected(x, out_dim,
with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
layer = tf.layers.Dense(
out_dim, activation=lambda x: nl(x, name='output'), use_bias=use_bias,
out_dim, activation=activation, use_bias=use_bias,
kernel_initializer=W_init, bias_initializer=b_init,
trainable=True)
ret = layer.apply(x, scope=tf.get_variable_scope())
......@@ -52,4 +52,4 @@ def FullyConnected(x, out_dim,
ret.variables = VariableHolder(W=layer.kernel)
if use_bias:
ret.variables.b = layer.bias
return ret
return tf.identity(ret, name='output')
......@@ -5,6 +5,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from ..utils.argtools import get_data_format
__all__ = ['LayerNorm', 'InstanceNorm']
......@@ -13,7 +14,7 @@ __all__ = ['LayerNorm', 'InstanceNorm']
def LayerNorm(
x, epsilon=1e-5,
use_bias=True, use_scale=True,
gamma_init=None, data_format='NHWC'):
gamma_init=None, data_format='channels_last'):
"""
Layer Normalization layer, as described in the paper:
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
......@@ -23,6 +24,7 @@ def LayerNorm(
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
"""
data_format = get_data_format(data_format, tfmode=False)
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
......@@ -62,7 +64,7 @@ def LayerNorm(
@layer_register()
def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
"""
Instance Normalization, as in the paper:
`Instance Normalization: The Missing Ingredient for Fast Stylization
......@@ -73,6 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
epsilon (float): avoid divide-by-zero
use_affine (bool): whether to apply learnable affine transformation
"""
data_format = get_data_format(data_format, tfmode=False)
shape = x.get_shape().as_list()
assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
......
......@@ -7,7 +7,7 @@ import numpy as np
from .shape_utils import StaticDynamicShape
from .common import layer_register
from ..utils.argtools import shape2d
from ..utils.argtools import shape2d, get_data_format
from ._test import TestModel
......@@ -16,7 +16,7 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
@layer_register(log_shape=True)
def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
def MaxPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
"""
Max Pooling on 4D tensors.
......@@ -31,13 +31,12 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
"""
if stride is None:
stride = shape
ret = tf.layers.max_pooling2d(x, shape, stride, padding,
'channels_last' if data_format == 'NHWC' else 'channels_first')
ret = tf.layers.max_pooling2d(x, shape, stride, padding, data_format=data_format)
return tf.identity(ret, name='output')
@layer_register(log_shape=True)
def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
def AvgPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
"""
Average Pooling on 4D tensors.
......@@ -52,13 +51,12 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
"""
if stride is None:
stride = shape
ret = tf.layers.average_pooling2d(x, shape, stride, padding,
'channels_last' if data_format == 'NHWC' else 'channels_first')
ret = tf.layers.average_pooling2d(x, shape, stride, padding, data_format=data_format)
return tf.identity(ret, name='output')
@layer_register(log_shape=True)
def GlobalAvgPooling(x, data_format='NHWC'):
def GlobalAvgPooling(x, data_format='channels_last'):
"""
Global average pooling as in the paper `Network In Network
<http://arxiv.org/abs/1312.4400>`_.
......@@ -69,8 +67,7 @@ def GlobalAvgPooling(x, data_format='NHWC'):
tf.Tensor: a NC tensor named ``output``.
"""
assert x.shape.ndims == 4
assert data_format in ['NHWC', 'NCHW']
axis = [1, 2] if data_format == 'NHWC' else [2, 3]
axis = [1, 2] if data_format == 'channels_last' else [2, 3]
return tf.reduce_mean(x, axis, name='output')
......@@ -90,7 +87,7 @@ def UnPooling2x2ZeroFilled(x):
@layer_register(log_shape=True)
def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
"""
Unpool the input by taking a Kronecker product with a fixed matrix.
......@@ -103,6 +100,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
Returns:
tf.Tensor: a 4D image tensor.
"""
data_format = get_data_format(data_format, tfmode=False)
shape = shape2d(shape)
output_shape = StaticDynamicShape(x)
......
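Because the keyword translation happens once in layer_register (see map_tfargs below), layers that now compare only against the tf.layers spelling, such as GlobalAvgPooling, should still accept the legacy strings when called through the registry. Continuing the illustrative sketch from above:
# Both calls average over the spatial axes [1, 2]; 'NHWC' is normalized to
# 'channels_last' before GlobalAvgPooling's body runs.
v1 = GlobalAvgPooling('gap1', x, data_format='channels_last')
v2 = GlobalAvgPooling('gap2', x, data_format='NHWC')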
......@@ -11,6 +11,7 @@ import copy
from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str
from ..utils.argtools import get_data_format
from ..utils import logger
# make sure each layer is only logged once
......@@ -20,6 +21,18 @@ _LAYER_REGISTRY = {}
__all__ = ['layer_register']
def map_tfargs(kwargs):
df = kwargs.pop('data_format', None)
if df is not None:
df = get_data_format(df, tfmode=True)
kwargs['data_format'] = df
old_nl = kwargs.pop('nl', None)
if old_nl is not None:
kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
return kwargs
def _register(name, func):
if name in _LAYER_REGISTRY:
raise ValueError("Layer named {} is already registered!".format(name))
......@@ -113,6 +126,7 @@ def layer_register(
if k in actual_args:
del actual_args[k]
actual_args = map_tfargs(actual_args)
if name is not None: # use scope
with tf.variable_scope(name) as scope:
# this name is only used to suppress logging; it doesn't hurt to do some heuristics
......
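To make the translation concrete, here is a small standalone sketch of what map_tfargs (defined in the hunk above) does to a legacy keyword dict; the values are illustrative:
kwargs = {'data_format': 'NHWC', 'nl': tf.nn.relu}
kwargs = map_tfargs(kwargs)
# kwargs is now {'data_format': 'channels_last', 'activation': <lambda>},
# where the lambda forwards the optional name= argument that old-style nl functions expect.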
......@@ -4,9 +4,7 @@
from contextlib import contextmanager
from collections import defaultdict
import inspect
import copy
import six
__all__ = ['argscope', 'get_arg_scope']
......@@ -35,14 +33,14 @@ def argscope(layers, **kwargs):
if not isinstance(layers, list):
layers = [layers]
def _check_args_exist(l):
args = inspect.getargspec(l).args
for k, v in six.iteritems(kwargs):
assert k in args, "No argument {} in {}".format(k, l.__name__)
# def _check_args_exist(l):
# args = inspect.getargspec(l).args
# for k, v in six.iteritems(kwargs):
# assert k in args, "No argument {} in {}".format(k, l.__name__)
for l in layers:
assert hasattr(l, 'symbolic_function'), "{} is not a registered layer".format(l.__name__)
_check_args_exist(l.symbolic_function)
# _check_args_exist(l.symbolic_function)
new_scope = copy.copy(get_arg_scope())
for l in layers:
......
......@@ -111,7 +111,18 @@ def shape2d(a):
raise RuntimeError("Illegal shape: {}".format(a))
def shape4d(a, data_format='NHWC'):
def get_data_format(data_format, tfmode=True):
if tfmode:
dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
else:
dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
ret = dic.get(data_format, data_format)
if ret not in dic.values():
raise ValueError("Unknown data_format: {}".format(data_format))
return ret
def shape4d(a, data_format='channels_last'):
"""
Ensure a 4D shape, to use with 4D symbolic functions.
......@@ -123,7 +134,7 @@ def shape4d(a, data_format='NHWC'):
or ``[1, 1, a, a]`` depending on data_format.
"""
s2d = shape2d(a)
if data_format == 'NHWC':
if get_data_format(data_format) == 'channels_last':
return [1] + s2d + [1]
else:
return [1, 1] + s2d
......
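An illustrative round trip through the new helper and the updated shape4d default (both live in ..utils.argtools as shown above):
get_data_format('NHWC')                           # -> 'channels_last'
get_data_format('channels_first', tfmode=False)   # -> 'NCHW'
get_data_format('NCHW', tfmode=False)             # -> 'NCHW' (already in target form, passed through)
get_data_format('NDHWC')                          # raises ValueError: Unknown data_format
shape4d(3, data_format='channels_first')          # -> [1, 1, 3, 3]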