Commit 152c4c2c authored by Yuxin Wu

Reimplement Conv2DTranspose to avoid Keras bugs.

parent d0e410ad
@@ -98,7 +98,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             don't want to fine tune the EMA. EMA will not be updated in
             this case.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
......
@@ -147,7 +147,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             this case.
     """
     # parse shapes
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4], ndims
......
@@ -80,7 +80,7 @@ def Conv2D(
     else:
         # group conv implementation
-        data_format = get_data_format(data_format, tfmode=False)
+        data_format = get_data_format(data_format, keras_mode=False)
         in_shape = inputs.get_shape().as_list()
         channel_axis = 3 if data_format == 'NHWC' else 1
         in_channel = in_shape[channel_axis]
@@ -163,6 +163,7 @@ def Conv2DTranspose(
         else:
             kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
+    if get_tf_version_tuple() <= (1, 12):
         with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
             layer = tf.layers.Conv2DTranspose(
                 filters,
@@ -180,10 +181,53 @@ def Conv2DTranspose(
                 _reuse=tf.get_variable_scope().reuse)
             ret = layer.apply(inputs, scope=tf.get_variable_scope())
             ret = tf.identity(ret, name='output')
         ret.variables = VariableHolder(W=layer.kernel)
         if use_bias:
             ret.variables.b = layer.bias
+    else:
+        # Our own implementation, to avoid Keras bugs. https://github.com/tensorflow/tensorflow/issues/25946
+        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
+            "Unsupported arguments due to bug in TensorFlow 1.13"
+        data_format = get_data_format(data_format, keras_mode=False)
+        shape_dyn = tf.shape(inputs)
+        strides2d = shape2d(strides)
+        channels_in = inputs.shape[1 if data_format == 'NCHW' else 3]
+        if data_format == 'NCHW':
+            out_shape_dyn = tf.stack(
+                [shape_dyn[0], filters,
+                 shape_dyn[2] * strides2d[0],
+                 shape_dyn[3] * strides2d[1]])
+            out_shape3_sta = [filters,
+                              None if inputs.shape[2] is None else inputs.shape[2] * strides2d[0],
+                              None if inputs.shape[3] is None else inputs.shape[3] * strides2d[1]]
+        else:
+            out_shape_dyn = tf.stack(
+                [shape_dyn[0],
+                 shape_dyn[1] * strides2d[0],
+                 shape_dyn[2] * strides2d[1],
+                 filters])
+            out_shape3_sta = [None if inputs.shape[1] is None else inputs.shape[1] * strides2d[0],
+                              None if inputs.shape[2] is None else inputs.shape[2] * strides2d[1],
+                              filters]
+        kernel_shape = shape2d(kernel_size)
+        W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
+        if use_bias:
+            b = tf.get_variable('b', [filters], initializer=bias_initializer)
+        conv = tf.nn.conv2d_transpose(
+            inputs, W, out_shape_dyn,
+            shape4d(strides, data_format=data_format),
+            padding=padding.upper(),
+            data_format=data_format)
+        conv.set_shape(tf.TensorShape([None] + out_shape3_sta))
+        ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+        ret.variables = VariableHolder(W=W)
+        if use_bias:
+            ret.variables.b = b
     return ret
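
The block above builds the dynamic output shape by hand and then restores the static shape that tf.stack discards. A minimal standalone sketch of the same call pattern (illustrative only, not part of the commit; it assumes TensorFlow 1.x, NHWC layout, stride 2, 'SAME' padding, and a hypothetical variable name W_demo):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 16, 16, 8])  # NHWC, unknown batch size
filters, stride = 32, 2
# conv2d_transpose kernels are laid out as [kh, kw, out_channels, in_channels]
W_demo = tf.get_variable('W_demo', [3, 3, filters, 8])
shape_dyn = tf.shape(x)
out_shape = tf.stack([shape_dyn[0],
                      shape_dyn[1] * stride,
                      shape_dyn[2] * stride,
                      filters])
y = tf.nn.conv2d_transpose(x, W_demo, out_shape, [1, stride, stride, 1],
                           padding='SAME', data_format='NHWC')
y.set_shape([None, 32, 32, filters])  # tf.stack drops static shape info; put it back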
......
@@ -24,7 +24,7 @@ def LayerNorm(
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -75,7 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         epsilon (float): avoid divide-by-zero
         use_affine (bool): whether to apply learnable affine transformation
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
......
@@ -102,7 +102,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
     Returns:
         tf.Tensor: a 4D image tensor.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = shape2d(shape)
     output_shape = StaticDynamicShape(x)
......
@@ -15,7 +15,7 @@ __all__ = []
 def map_common_tfargs(kwargs):
     df = kwargs.pop('data_format', None)
     if df is not None:
-        df = get_data_format(df, tfmode=True)
+        df = get_data_format(df, keras_mode=True)
         kwargs['data_format'] = df
     old_nl = kwargs.pop('nl', None)
......
@@ -340,7 +340,7 @@ class HorovodTrainer(SingleCostTrainer):
           for Horovod installation and in the MPI command line.
           See Horovod docs for details.
-       2. Due to a TF bug, you must not initialize CUDA context before the trainer starts training.
+       2. Due to a TF bug (#8136), you must not initialize CUDA context before the trainer starts training.
           Therefore TF functions like `is_gpu_available()` or `list_local_devices()`
           must be avoided.
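
As a quick illustration of this caveat (a hypothetical snippet, not part of the commit): both calls below initialize a CUDA context in the launching process, so they must not run before HorovodTrainer starts training.

import tensorflow as tf
from tensorflow.python.client import device_lib

# Each of these initializes a CUDA context in the current process;
# under HorovodTrainer, neither may run before training starts.
print(tf.test.is_gpu_available())
print(device_lib.list_local_devices())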
......
@@ -104,8 +104,8 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))
 
-def get_data_format(data_format, tfmode=True):
-    if tfmode:
+def get_data_format(data_format, keras_mode=True):
+    if keras_mode:
         dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
     else:
         dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
@@ -115,7 +115,7 @@ def get_data_format(data_format, tfmode=True):
     return ret
 
-def shape4d(a, data_format='channels_last'):
+def shape4d(a, data_format='NHWC'):
     """
     Ensure a 4D shape, to use with 4D symbolic functions.
@@ -127,7 +127,7 @@ def shape4d(a, data_format='channels_last'):
         or ``[1, 1, a, a]`` depending on data_format.
     """
     s2d = shape2d(a)
-    if get_data_format(data_format) == 'channels_last':
+    if get_data_format(data_format, False) == 'NHWC':
         return [1] + s2d + [1]
     else:
         return [1, 1] + s2d
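
Since the rename only swaps the keyword, behavior is unchanged; a small sanity check based on the mappings above (a hedged sketch: it assumes these helpers live in tensorpack.utils.argtools, and that get_data_format passes through values already in the target convention, as the new shape4d default implies):

from tensorpack.utils.argtools import get_data_format, shape4d

assert get_data_format('channels_first', keras_mode=False) == 'NCHW'
assert get_data_format('NCHW', keras_mode=True) == 'channels_first'
assert shape4d(3, data_format='NHWC') == [1, 3, 3, 1]  # channels-last: [1, a, a, 1]
assert shape4d(3, data_format='NCHW') == [1, 1, 3, 3]  # channels-first: [1, 1, a, a]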
......