Commit ea60a630 authored by Yuxin Wu

Translate data_format and activation to tflayers (#627)

parent 0e5299bb
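As a rough illustration of what this commit means at call sites (a minimal sketch: `Conv2D` and the argument names come from the diff below, while the scope name `'conv1'`, the input `x`, and `tf.nn.relu` are placeholders), the old and new spellings look like this; the `map_tfargs` hook added further down keeps the old spellings working:

```python
import tensorflow as tf

# Old-style tensorpack call: custom "nl" argument and TF-op-style data_format.
# Still accepted after this commit -- map_tfargs() translates both arguments.
y_old = Conv2D('conv1', x, 64, 3, nl=tf.nn.relu, data_format='NHWC')

# New-style call following the tf.layers naming convention.
y_new = Conv2D('conv1', x, 64, 3, activation=tf.nn.relu, data_format='channels_last')
```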
@@ -8,6 +8,7 @@ from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
 from ..utils import logger
+from ..utils.argtools import get_data_format
 from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
@@ -67,7 +68,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
+              gamma_init=tf.constant_initializer(1.0),
+              data_format='channels_last',
               internal_update=False):
     """
     Batch Normalization layer, as described in the paper:
@@ -109,6 +111,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
             don't want to fine tune the EMA. EMA will not be updated in
             this case.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -181,7 +184,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 @layer_register()
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None, data_format='NHWC'):
+                use_scale=True, use_bias=True, gamma_init=None,
+                data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
@@ -210,18 +214,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     ndims = len(shape)
     assert ndims in [2, 4]
     if ndims == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+        data_format = 'channels_last'    # error using NCHW? (see #190)
         x = tf.reshape(x, [-1, 1, 1, shape[1]])
-    if data_format == 'NCHW':
-        n_out = shape[1]
-    else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
     ctx = get_current_tower_context()
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
-        axis=1 if data_format == 'NCHW' else 3,
+        axis=1 if data_format == 'channels_first' else 3,
         momentum=decay, epsilon=epsilon,
         center=use_bias, scale=use_scale,
         renorm=True,
...
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from .common import layer_register, VariableHolder, rename_get_variable
 from ..tfutils.common import get_tf_version_number
-from ..utils.argtools import shape2d, shape4d
+from ..utils.argtools import shape2d, shape4d, get_data_format
 __all__ = ['Conv2D', 'Deconv2D']
@@ -15,8 +15,8 @@ __all__ = ['Conv2D', 'Deconv2D']
 def Conv2D(x, out_channel, kernel_shape,
            padding='SAME', stride=1,
            W_init=None, b_init=None,
-           nl=tf.identity, split=1, use_bias=True,
-           data_format='NHWC', dilation_rate=1):
+           activation=tf.identity, split=1, use_bias=True,
+           data_format='channels_last', dilation_rate=1):
     """
     2D convolution on 4D inputs.
@@ -30,9 +30,7 @@ def Conv2D(x, out_channel, kernel_shape,
         split (int): Split channels as used in Alexnet. Defaults to 1 (no split).
         W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
-        data_format (str): 'NHWC' or 'NCHW'.
         dilation_rate: (h, w) tuple or a int.
     Returns:
@@ -43,6 +41,7 @@ def Conv2D(x, out_channel, kernel_shape,
         * ``W``: weights
         * ``b``: bias
     """
+    data_format = get_data_format(data_format, tfmode=False)
     in_shape = x.get_shape().as_list()
     channel_axis = 3 if data_format == 'NHWC' else 1
     in_channel = in_shape[channel_axis]
@@ -79,7 +78,7 @@ def Conv2D(x, out_channel, kernel_shape,
                    for i, k in zip(inputs, kernels)]
         conv = tf.concat(outputs, channel_axis)
-    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+    ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
     ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
@@ -90,8 +89,8 @@ def Conv2D(x, out_channel, kernel_shape,
 def Deconv2D(x, out_channel, kernel_shape,
              stride, padding='SAME',
             W_init=None, b_init=None,
-             nl=tf.identity, use_bias=True,
-             data_format='NHWC'):
+             activation=tf.identity, use_bias=True,
+             data_format='channels_last'):
     """
     2D deconvolution on 4D inputs.
@@ -104,7 +103,6 @@ def Deconv2D(x, out_channel, kernel_shape,
         padding (str): 'valid' or 'same'. Case insensitive.
         W_init: initializer for W. Defaults to `tf.variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
     Returns:
@@ -115,13 +113,6 @@ def Deconv2D(x, out_channel, kernel_shape,
         * ``W``: weights
         * ``b``: bias
     """
-    in_shape = x.get_shape().as_list()
-    channel_axis = 3 if data_format == 'NHWC' else 1
-    in_channel = in_shape[channel_axis]
-    assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
-    assert isinstance(out_channel, int), out_channel
     if W_init is None:
         W_init = tf.variance_scaling_initializer(scale=2.0)
     if b_init is None:
@@ -131,8 +122,8 @@ def Deconv2D(x, out_channel, kernel_shape,
         layer = tf.layers.Conv2DTranspose(
             out_channel, kernel_shape,
             strides=stride, padding=padding,
-            data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
-            activation=lambda x: nl(x, name='output'),
+            data_format=data_format,
+            activation=activation,
             use_bias=use_bias,
             kernel_initializer=W_init,
             bias_initializer=b_init,
@@ -142,4 +133,4 @@ def Deconv2D(x, out_channel, kernel_shape,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
@@ -14,7 +14,7 @@ __all__ = ['FullyConnected']
 @layer_register(log_shape=True)
 def FullyConnected(x, out_dim,
                    W_init=None, b_init=None,
-                   nl=tf.identity, use_bias=True):
+                   activation=tf.identity, use_bias=True):
     """
     Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor.
     It is an equivalent of `tf.layers.dense` except for naming conventions.
@@ -44,7 +44,7 @@ def FullyConnected(x, out_dim,
     with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
         layer = tf.layers.Dense(
-            out_dim, activation=lambda x: nl(x, name='output'), use_bias=use_bias,
+            out_dim, activation=activation, use_bias=use_bias,
             kernel_initializer=W_init, bias_initializer=b_init,
             trainable=True)
     ret = layer.apply(x, scope=tf.get_variable_scope())
@@ -52,4 +52,4 @@ def FullyConnected(x, out_dim,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from .common import layer_register, VariableHolder
+from ..utils.argtools import get_data_format
 __all__ = ['LayerNorm', 'InstanceNorm']
@@ -13,7 +14,7 @@ __all__ = ['LayerNorm', 'InstanceNorm']
 def LayerNorm(
         x, epsilon=1e-5,
         use_bias=True, use_scale=True,
-        gamma_init=None, data_format='NHWC'):
+        gamma_init=None, data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
@@ -23,6 +24,7 @@ def LayerNorm(
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -62,7 +64,7 @@ def LayerNorm(
 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
+def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
@@ -73,6 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         epsilon (float): avoid divide-by-zero
         use_affine (bool): whether to apply learnable affine transformation
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
...
@@ -7,7 +7,7 @@ import numpy as np
 from .shape_utils import StaticDynamicShape
 from .common import layer_register
-from ..utils.argtools import shape2d
+from ..utils.argtools import shape2d, get_data_format
 from ._test import TestModel
@@ -16,7 +16,7 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
 @layer_register(log_shape=True)
-def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def MaxPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Max Pooling on 4D tensors.
@@ -31,13 +31,12 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.max_pooling2d(x, shape, stride, padding,
-                                  'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.max_pooling2d(x, shape, stride, padding, data_format=data_format)
     return tf.identity(ret, name='output')
 @layer_register(log_shape=True)
-def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def AvgPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Average Pooling on 4D tensors.
@@ -52,13 +51,12 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.average_pooling2d(x, shape, stride, padding,
-                                      'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.average_pooling2d(x, shape, stride, padding, data_format=data_format)
     return tf.identity(ret, name='output')
 @layer_register(log_shape=True)
-def GlobalAvgPooling(x, data_format='NHWC'):
+def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
     <http://arxiv.org/abs/1312.4400>`_.
@@ -69,8 +67,7 @@ def GlobalAvgPooling(x, data_format='NHWC'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
-    assert data_format in ['NHWC', 'NCHW']
-    axis = [1, 2] if data_format == 'NHWC' else [2, 3]
+    axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
@@ -90,7 +87,7 @@ def UnPooling2x2ZeroFilled(x):
 @layer_register(log_shape=True)
-def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
+def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
     """
     Unpool the input with a fixed matrix to perform kronecker product with.
@@ -103,6 +100,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
     Returns:
         tf.Tensor: a 4D image tensor.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = shape2d(shape)
     output_shape = StaticDynamicShape(x)
...
@@ -11,6 +11,7 @@ import copy
 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.model_utils import get_shape_str
+from ..utils.argtools import get_data_format
 from ..utils import logger
 # make sure each layer is only logged once
@@ -20,6 +21,18 @@ _LAYER_REGISTRY = {}
 __all__ = ['layer_register']
+def map_tfargs(kwargs):
+    df = kwargs.pop('data_format', None)
+    if df is not None:
+        df = get_data_format(df, tfmode=True)
+        kwargs['data_format'] = df
+    old_nl = kwargs.pop('nl', None)
+    if old_nl is not None:
+        kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
+    return kwargs
 def _register(name, func):
     if name in _LAYER_REGISTRY:
         raise ValueError("Layer named {} is already registered!".format(name))
@@ -113,6 +126,7 @@ def layer_register(
             if k in actual_args:
                 del actual_args[k]
+            actual_args = map_tfargs(actual_args)
             if name is not None:  # use scope
                 with tf.variable_scope(name) as scope:
                     # this name is only used to surpress logging, doesn't hurt to do some heuristics
...
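For readers skimming the hunk above: a hedged sketch of what the new `map_tfargs` hook does to legacy keyword arguments before a layer body runs (the dictionary contents are illustrative; `tf.nn.relu` stands in for any old-style `nl` argument):

```python
import tensorflow as tf

# Illustrative kwargs as a user might still pass them after this commit.
legacy_kwargs = {'nl': tf.nn.relu, 'data_format': 'NCHW'}

mapped = map_tfargs(dict(legacy_kwargs))
# 'data_format' is translated to the tf.layers spelling:
assert mapped['data_format'] == 'channels_first'
# 'nl' is repackaged as an 'activation' callable that forwards the tensor name:
assert 'nl' not in mapped and callable(mapped['activation'])
```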
@@ -4,9 +4,7 @@
 from contextlib import contextmanager
 from collections import defaultdict
-import inspect
 import copy
-import six
 __all__ = ['argscope', 'get_arg_scope']
@@ -35,14 +33,14 @@ def argscope(layers, **kwargs):
     if not isinstance(layers, list):
         layers = [layers]
-    def _check_args_exist(l):
-        args = inspect.getargspec(l).args
-        for k, v in six.iteritems(kwargs):
-            assert k in args, "No argument {} in {}".format(k, l.__name__)
+    # def _check_args_exist(l):
+    #     args = inspect.getargspec(l).args
+    #     for k, v in six.iteritems(kwargs):
+    #         assert k in args, "No argument {} in {}".format(k, l.__name__)
     for l in layers:
         assert hasattr(l, 'symbolic_function'), "{} is not a registered layer".format(l.__name__)
-        _check_args_exist(l.symbolic_function)
+        # _check_args_exist(l.symbolic_function)
     new_scope = copy.copy(get_arg_scope())
     for l in layers:
...
@@ -111,7 +111,18 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))
-def shape4d(a, data_format='NHWC'):
+def get_data_format(data_format, tfmode=True):
+    if tfmode:
+        dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
+    else:
+        dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
+    ret = dic.get(data_format, data_format)
+    if ret not in dic.values():
+        raise ValueError("Unknown data_format: {}".format(data_format))
+    return ret
+def shape4d(a, data_format='channels_last'):
     """
     Ensuer a 4D shape, to use with 4D symbolic functions.
@@ -123,7 +134,7 @@ def shape4d(a, data_format='NHWC'):
         or ``[1, 1, a, a]`` depending on data_format.
     """
     s2d = shape2d(a)
-    if data_format == 'NHWC':
+    if get_data_format(data_format) == 'channels_last':
        return [1] + s2d + [1]
     else:
         return [1, 1] + s2d
...
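Finally, a self-contained sketch of how the new `get_data_format` helper behaves (the function body is copied verbatim from the hunk above; the assertions only illustrate the mapping in both directions):

```python
# Copy of the helper added in this commit, reproduced for illustration only.
def get_data_format(data_format, tfmode=True):
    if tfmode:
        dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
    else:
        dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
    ret = dic.get(data_format, data_format)
    if ret not in dic.values():
        raise ValueError("Unknown data_format: {}".format(data_format))
    return ret

assert get_data_format('NCHW') == 'channels_first'          # legacy name -> tf.layers name
assert get_data_format('channels_last') == 'channels_last'  # already in tf.layers form
assert get_data_format('NHWC', tfmode=False) == 'NHWC'      # tfmode=False keeps TF-op names
# get_data_format('NWC') raises ValueError: "Unknown data_format: NWC"
```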