Commit 6e93970f authored by yg320's avatar yg320 Committed by Yuxin Wu

Transpose conv bug (#1326)

* Fixed a bug when calculating output shape of valid/full padded Conv2DTranspose

* Added valid padding test and nchw test to TestConv2DTranspose

* Removed NCHW test from TestConv2DTranspose (Conv2DCustomBackpropInputOp only supports NHWC)

* removed black line

* Added valid padding test back to TestConv2DTranspose

* tf.Dimension(None) is None never evaluates to True
parent 09364de3
...@@ -206,6 +206,7 @@ def Conv2DTranspose( ...@@ -206,6 +206,7 @@ def Conv2DTranspose(
"Unsupported arguments due to Keras bug in TensorFlow 1.13" "Unsupported arguments due to Keras bug in TensorFlow 1.13"
data_format = get_data_format(data_format, keras_mode=False) data_format = get_data_format(data_format, keras_mode=False)
shape_dyn = tf.shape(inputs) shape_dyn = tf.shape(inputs)
shape_sta = inputs.shape.as_list()
strides2d = shape2d(strides) strides2d = shape2d(strides)
kernel_shape = shape2d(kernel_size) kernel_shape = shape2d(kernel_size)
...@@ -218,23 +219,23 @@ def Conv2DTranspose( ...@@ -218,23 +219,23 @@ def Conv2DTranspose(
shape_res2d = shape2d(0) shape_res2d = shape2d(0)
if data_format == 'NCHW': if data_format == 'NCHW':
channels_in = inputs.shape[1] channels_in = shape_sta[1]
out_shape_dyn = tf.stack( out_shape_dyn = tf.stack(
[shape_dyn[0], filters, [shape_dyn[0], filters,
shape_dyn[2] * strides2d[0] + shape_res2d[0], shape_dyn[2] * strides2d[0] + shape_res2d[0],
shape_dyn[3] * strides2d[1] + shape_res2d[1]]) shape_dyn[3] * strides2d[1] + shape_res2d[1]])
out_shape3_sta = [filters, out_shape3_sta = [filters,
None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[0]) + shape_res2d[0], None if shape_sta[2] is None else shape_sta[2] * strides2d[0] + shape_res2d[0],
None if inputs.shape[3] is None else int(inputs.shape[3] * strides2d[1]) + shape_res2d[1]] None if shape_sta[3] is None else shape_sta[3] * strides2d[1] + shape_res2d[1]]
else: else:
channels_in = inputs.shape[-1] channels_in = shape_sta[-1]
out_shape_dyn = tf.stack( out_shape_dyn = tf.stack(
[shape_dyn[0], [shape_dyn[0],
shape_dyn[1] * strides2d[0] + shape_res2d[0], shape_dyn[1] * strides2d[0] + shape_res2d[0],
shape_dyn[2] * strides2d[1] + shape_res2d[1], shape_dyn[2] * strides2d[1] + shape_res2d[1],
filters]) filters])
out_shape3_sta = [None if inputs.shape[1] is None else int(inputs.shape[1] * strides2d[0]) + shape_res2d[0], out_shape3_sta = [None if shape_sta[1] is None else shape_sta[1] * strides2d[0] + shape_res2d[0],
None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[1]) + shape_res2d[1], None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1],
filters] filters]
W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer) W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment