Commit 6e93970f authored by yg320's avatar yg320 Committed by Yuxin Wu

Transpose conv bug (#1326)

* Fixed a bug when calculating output shape of valid/full padded Conv2DTranspose

* Added valid padding test and nchw test to TestConv2DTranspose

* Removed NCHW test from TestConv2DTranspose (Conv2DCustomBackpropInputOp only supports NHWC)

* removed black line

* Added valid padding test back to TestConv2DTranspose

* tf.Dimension(None) is None never evaluates to True
parent 09364de3
......@@ -206,6 +206,7 @@ def Conv2DTranspose(
"Unsupported arguments due to Keras bug in TensorFlow 1.13"
data_format = get_data_format(data_format, keras_mode=False)
shape_dyn = tf.shape(inputs)
shape_sta = inputs.shape.as_list()
strides2d = shape2d(strides)
kernel_shape = shape2d(kernel_size)
......@@ -218,23 +219,23 @@ def Conv2DTranspose(
shape_res2d = shape2d(0)
if data_format == 'NCHW':
channels_in = inputs.shape[1]
channels_in = shape_sta[1]
out_shape_dyn = tf.stack(
[shape_dyn[0], filters,
shape_dyn[2] * strides2d[0] + shape_res2d[0],
shape_dyn[3] * strides2d[1] + shape_res2d[1]])
out_shape3_sta = [filters,
None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[0]) + shape_res2d[0],
None if inputs.shape[3] is None else int(inputs.shape[3] * strides2d[1]) + shape_res2d[1]]
None if shape_sta[2] is None else shape_sta[2] * strides2d[0] + shape_res2d[0],
None if shape_sta[3] is None else shape_sta[3] * strides2d[1] + shape_res2d[1]]
else:
channels_in = inputs.shape[-1]
channels_in = shape_sta[-1]
out_shape_dyn = tf.stack(
[shape_dyn[0],
shape_dyn[1] * strides2d[0] + shape_res2d[0],
shape_dyn[2] * strides2d[1] + shape_res2d[1],
filters])
out_shape3_sta = [None if inputs.shape[1] is None else int(inputs.shape[1] * strides2d[0]) + shape_res2d[0],
None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[1]) + shape_res2d[1],
out_shape3_sta = [None if shape_sta[1] is None else shape_sta[1] * strides2d[0] + shape_res2d[0],
None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1],
filters]
W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment