Commit 09364de3 authored by yg320, committed by Yuxin Wu

Transpose Convolution Valid/Full Padding (#1323)

* Fixed a bug when calculating the output shape of a valid/full-padded Conv2DTranspose

* Added a valid-padding test and an NCHW test to TestConv2DTranspose

* Removed the NCHW test from TestConv2DTranspose (Conv2DCustomBackpropInputOp only supports NHWC)

* Removed a blank line

* Added the valid-padding test back to TestConv2DTranspose
parent f7a79d48
@@ -207,28 +207,36 @@ def Conv2DTranspose(
     data_format = get_data_format(data_format, keras_mode=False)
     shape_dyn = tf.shape(inputs)
     strides2d = shape2d(strides)
-    channels_in = inputs.shape[1 if data_format == 'NCHW' else 3]
+    kernel_shape = shape2d(kernel_size)
+
+    assert padding.lower() in ['valid', 'same'], "Padding {} is not supported!".format(padding)
+
+    if padding.lower() == 'valid':
+        shape_res2d = [max(kernel_shape[0] - strides2d[0], 0),
+                       max(kernel_shape[1] - strides2d[1], 0)]
+    else:
+        shape_res2d = shape2d(0)
+
     if data_format == 'NCHW':
         channels_in = inputs.shape[1]
         out_shape_dyn = tf.stack(
             [shape_dyn[0], filters,
-             shape_dyn[2] * strides2d[0],
-             shape_dyn[3] * strides2d[1]])
+             shape_dyn[2] * strides2d[0] + shape_res2d[0],
+             shape_dyn[3] * strides2d[1] + shape_res2d[1]])
         out_shape3_sta = [filters,
-                          None if inputs.shape[2] is None else inputs.shape[2] * strides2d[0],
-                          None if inputs.shape[3] is None else inputs.shape[3] * strides2d[1]]
+                          None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[0]) + shape_res2d[0],
+                          None if inputs.shape[3] is None else int(inputs.shape[3] * strides2d[1]) + shape_res2d[1]]
     else:
         channels_in = inputs.shape[-1]
         out_shape_dyn = tf.stack(
             [shape_dyn[0],
-             shape_dyn[1] * strides2d[0],
-             shape_dyn[2] * strides2d[1],
+             shape_dyn[1] * strides2d[0] + shape_res2d[0],
+             shape_dyn[2] * strides2d[1] + shape_res2d[1],
              filters])
-        out_shape3_sta = [None if inputs.shape[1] is None else inputs.shape[1] * strides2d[0],
-                          None if inputs.shape[2] is None else inputs.shape[2] * strides2d[1],
+        out_shape3_sta = [None if inputs.shape[1] is None else int(inputs.shape[1] * strides2d[0]) + shape_res2d[0],
+                          None if inputs.shape[2] is None else int(inputs.shape[2] * strides2d[1]) + shape_res2d[1],
                           filters]
 
-    kernel_shape = shape2d(kernel_size)
     W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
     if use_bias:
         b = tf.get_variable('b', [filters], initializer=bias_initializer)
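For context, the per-dimension output size computed by the patched code reduces to the following rule. This is a minimal illustrative sketch of that rule, not part of the patch; the helper name is made up, and the usual tf.nn.conv2d_transpose shape semantics are assumed.

def deconv_output_length(input_len, kernel, stride, padding):
    # Mirrors the diff above: SAME keeps input_len * stride,
    # VALID adds back the residue max(kernel - stride, 0).
    assert padding.lower() in ['valid', 'same'], padding
    if padding.lower() == 'same':
        return input_len * stride
    return input_len * stride + max(kernel - stride, 0)

# Example: a 12x18 input with a 3x3 kernel and stride 2 gives
#   SAME  -> 24 x 36
#   VALID -> 25 x 37  (equivalently (in - 1) * stride + kernel when kernel >= stride)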
@@ -76,7 +76,7 @@ class TestConv2DTranspose(TestModel):
     def test_shape_match(self):
         h, w = 12, 18
         input = self.make_variable(np.random.rand(1, h, w, 3).astype("float32"))
-        for padding in ["same"]:
+        for padding in ["same", "valid"]:
             for stride in [1, 2]:
                 output = Conv2DTranspose(
                     'deconv_s{}_pad{}'.format(stride, padding),
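Using the illustrative deconv_output_length helper sketched earlier, and assuming a 3x3 kernel (the kernel size is outside this hunk, so the numbers are only an illustration), the spatial sizes the extended loop would exercise are:

for padding in ["same", "valid"]:
    for stride in [1, 2]:
        print(padding, stride,
              deconv_output_length(12, 3, stride, padding),
              deconv_output_length(18, 3, stride, padding))
# same  1 -> 12 18    same  2 -> 24 36
# valid 1 -> 14 20    valid 2 -> 25 37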