Commit 152c4c2c authored by Yuxin Wu

Reimplement Conv2DTranspose to avoid Keras bugs.

parent d0e410ad
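For context: the new code path introduced below avoids tf.layers.Conv2DTranspose by calling tf.nn.conv2d_transpose directly and computing the output shape by hand. A minimal sketch of that idea, assuming NHWC layout, 'SAME' padding, and a kernel W of shape [kh, kw, out_channels, in_channels] (the function name here is illustrative, not part of the commit):

    import tensorflow as tf

    def deconv_sketch(x, W, filters, stride=2):
        # Dynamic input shape [N, H, W, C]; under 'SAME' padding the
        # output spatial size is exactly input size * stride.
        s = tf.shape(x)
        out_shape = tf.stack([s[0], s[1] * stride, s[2] * stride, filters])
        return tf.nn.conv2d_transpose(x, W, out_shape,
                                      strides=[1, stride, stride, 1],
                                      padding='SAME', data_format='NHWC')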
@@ -98,7 +98,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             don't want to fine tune the EMA. EMA will not be updated in
             this case.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]

@@ -147,7 +147,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             this case.
     """
     # parse shapes
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4], ndims

@@ -80,7 +80,7 @@ def Conv2D(
     else:
         # group conv implementation
-        data_format = get_data_format(data_format, tfmode=False)
+        data_format = get_data_format(data_format, keras_mode=False)
         in_shape = inputs.get_shape().as_list()
         channel_axis = 3 if data_format == 'NHWC' else 1
         in_channel = in_shape[channel_axis]

@@ -163,27 +163,71 @@ def Conv2DTranspose(
     else:
         kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
-    with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
-        layer = tf.layers.Conv2DTranspose(
-            filters,
-            kernel_size,
-            strides=strides,
-            padding=padding,
-            data_format=data_format,
-            activation=activation,
-            use_bias=use_bias,
-            kernel_initializer=kernel_initializer,
-            bias_initializer=bias_initializer,
-            kernel_regularizer=kernel_regularizer,
-            bias_regularizer=bias_regularizer,
-            activity_regularizer=activity_regularizer,
-            _reuse=tf.get_variable_scope().reuse)
-        ret = layer.apply(inputs, scope=tf.get_variable_scope())
-        ret = tf.identity(ret, name='output')
-    ret.variables = VariableHolder(W=layer.kernel)
-    if use_bias:
-        ret.variables.b = layer.bias
+    if get_tf_version_tuple() <= (1, 12):
+        with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
+            layer = tf.layers.Conv2DTranspose(
+                filters,
+                kernel_size,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+                activation=activation,
+                use_bias=use_bias,
+                kernel_initializer=kernel_initializer,
+                bias_initializer=bias_initializer,
+                kernel_regularizer=kernel_regularizer,
+                bias_regularizer=bias_regularizer,
+                activity_regularizer=activity_regularizer,
+                _reuse=tf.get_variable_scope().reuse)
+            ret = layer.apply(inputs, scope=tf.get_variable_scope())
+            ret = tf.identity(ret, name='output')
+        ret.variables = VariableHolder(W=layer.kernel)
+        if use_bias:
+            ret.variables.b = layer.bias
+    else:
+        # Our own implementation, to avoid Keras bugs. https://github.com/tensorflow/tensorflow/issues/25946
+        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
+            "Unsupported arguments due to bug in TensorFlow 1.13"
+        data_format = get_data_format(data_format, keras_mode=False)
+        shape_dyn = tf.shape(inputs)
+        strides2d = shape2d(strides)
+        # The kernel of conv2d_transpose has shape [kh, kw, out_channels, in_channels].
+        channels_in = inputs.shape[1 if data_format == 'NCHW' else 3]
+        if data_format == 'NCHW':
+            out_shape_dyn = tf.stack(
+                [shape_dyn[0], filters,
+                 shape_dyn[2] * strides2d[0],
+                 shape_dyn[3] * strides2d[1]])
+            out_shape3_sta = [filters,
+                              None if inputs.shape[2] is None else inputs.shape[2] * strides2d[0],
+                              None if inputs.shape[3] is None else inputs.shape[3] * strides2d[1]]
+        else:
+            out_shape_dyn = tf.stack(
+                [shape_dyn[0],
+                 shape_dyn[1] * strides2d[0],
+                 shape_dyn[2] * strides2d[1],
+                 filters])
+            out_shape3_sta = [None if inputs.shape[1] is None else inputs.shape[1] * strides2d[0],
+                              None if inputs.shape[2] is None else inputs.shape[2] * strides2d[1],
+                              filters]
+
+        kernel_shape = shape2d(kernel_size)
+        W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
+        if use_bias:
+            b = tf.get_variable('b', [filters], initializer=bias_initializer)
+        conv = tf.nn.conv2d_transpose(
+            inputs, W, out_shape_dyn,
+            shape4d(strides, data_format=data_format),
+            padding=padding.upper(),
+            data_format=data_format)
+        conv.set_shape(tf.TensorShape([None] + out_shape3_sta))
+        ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+
+        ret.variables = VariableHolder(W=W)
+        if use_bias:
+            ret.variables.b = b
     return ret

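To make the shape bookkeeping in the new branch concrete: with 'SAME' padding and stride s, each spatial dimension of the output is s times the input, which is what both tf.stack (dynamic shape) and set_shape (static shape) encode above. A hedged usage sketch against TF 1.x, with illustrative shapes and variable names:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 8, 8, 16])      # NHWC, 16 channels in
    W = tf.get_variable('W_demo', [3, 3, 32, 16])         # [kh, kw, out_ch, in_ch]
    s = tf.shape(x)
    out_shape = tf.stack([s[0], s[1] * 2, s[2] * 2, 32])  # filters=32, stride 2
    y = tf.nn.conv2d_transpose(x, W, out_shape, [1, 2, 2, 1], padding='SAME')
    y.set_shape([None, 16, 16, 32])                       # static: input * stride
    print(y.shape)                                        # (?, 16, 16, 32)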
@@ -24,7 +24,7 @@ def LayerNorm(
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]

@@ -75,7 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         epsilon (float): avoid divide-by-zero
         use_affine (bool): whether to apply learnable affine transformation
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

@@ -102,7 +102,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
     Returns:
         tf.Tensor: a 4D image tensor.
     """
-    data_format = get_data_format(data_format, tfmode=False)
+    data_format = get_data_format(data_format, keras_mode=False)
     shape = shape2d(shape)
     output_shape = StaticDynamicShape(x)

@@ -15,7 +15,7 @@ __all__ = []
 def map_common_tfargs(kwargs):
     df = kwargs.pop('data_format', None)
     if df is not None:
-        df = get_data_format(df, tfmode=True)
+        df = get_data_format(df, keras_mode=True)
         kwargs['data_format'] = df
     old_nl = kwargs.pop('nl', None)

@@ -340,7 +340,7 @@ class HorovodTrainer(SingleCostTrainer):
            for Horovod installation and in the MPI command line.
            See Horovod docs for details.
-        2. Due to a TF bug, you must not initialize CUDA context before the trainer starts training.
+        2. Due to a TF bug (#8136), you must not initialize CUDA context before the trainer starts training.
            Therefore TF functions like `is_gpu_available()` or `list_local_devices()`
            must be avoided.

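One way to respect this constraint (an illustrative sketch, not tensorpack's API) is to pin each process to a GPU using Horovod's local rank instead of probing devices, since probing is what creates the CUDA context too early:

    import os
    import horovod.tensorflow as hvd

    hvd.init()
    # Select one GPU per process without initializing CUDA through TF;
    # tf.test.is_gpu_available() or device_lib.list_local_devices() would
    # do exactly what the note above warns against.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())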
@@ -104,8 +104,8 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))

-def get_data_format(data_format, tfmode=True):
-    if tfmode:
+def get_data_format(data_format, keras_mode=True):
+    if keras_mode:
         dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
     else:
         dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
@@ -115,7 +115,7 @@ def get_data_format(data_format, tfmode=True):
     return ret

-def shape4d(a, data_format='channels_last'):
+def shape4d(a, data_format='NHWC'):
     """
     Ensure a 4D shape, to use with 4D symbolic functions.
@@ -127,7 +127,7 @@ def shape4d(a, data_format='channels_last'):
         or ``[1, 1, a, a]`` depending on data_format.
     """
     s2d = shape2d(a)
-    if get_data_format(data_format) == 'channels_last':
+    if get_data_format(data_format, False) == 'NHWC':
         return [1] + s2d + [1]
     else:
         return [1, 1] + s2d
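To make the renamed keyword concrete, the two directions of get_data_format and the new TF-style default of shape4d behave as sketched below (assuming both helpers are importable from tensorpack.utils.argtools; that import path is an assumption, not stated in this diff):

    from tensorpack.utils.argtools import get_data_format, shape4d

    # keras_mode=True maps TF names to Keras names; False maps back.
    assert get_data_format('NCHW', keras_mode=True) == 'channels_first'
    assert get_data_format('channels_last', keras_mode=False) == 'NHWC'

    # shape4d broadcasts a scalar kernel/stride into a 4D spec:
    assert shape4d(3, data_format='NHWC') == [1, 3, 3, 1]
    assert shape4d(3, data_format='NCHW') == [1, 1, 3, 3]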