Commit 4ac2e22b authored by gehuangyi20's avatar gehuangyi20 Committed by Yuxin Wu

fix input dtype not matching variable dtype (#1386)

* fix not input dtype not matching variable dtype

* fix input dtype not matching variable dtype

* Fix incorrect variable name input_dtype -> inputs_dtype

* Fix code format

* Fix code format
parent 12ad257c
...@@ -104,11 +104,13 @@ def Conv2D( ...@@ -104,11 +104,13 @@ def Conv2D(
if get_tf_version_tuple() >= (1, 5): if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format) kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)
# matching input dtype (ex. tf.float16) since the default dtype of variable if tf.float32
inputs_dtype = inputs.dtype
W = tf.get_variable( W = tf.get_variable(
'W', filter_shape, initializer=kernel_initializer) 'W', filter_shape, dtype=inputs_dtype, initializer=kernel_initializer)
if use_bias: if use_bias:
b = tf.get_variable('b', [out_channel], initializer=bias_initializer) b = tf.get_variable('b', [out_channel], dtype=inputs_dtype, initializer=bias_initializer)
if split == 1: if split == 1:
conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs) conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
...@@ -238,9 +240,11 @@ def Conv2DTranspose( ...@@ -238,9 +240,11 @@ def Conv2DTranspose(
None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1], None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1],
filters] filters]
W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer) inputs_dtype = inputs.dtype
W = tf.get_variable('W', kernel_shape + [filters, channels_in],
dtype=inputs_dtype, initializer=kernel_initializer)
if use_bias: if use_bias:
b = tf.get_variable('b', [filters], initializer=bias_initializer) b = tf.get_variable('b', [filters], dtype=inputs_dtype, initializer=bias_initializer)
conv = tf.nn.conv2d_transpose( conv = tf.nn.conv2d_transpose(
inputs, W, out_shape_dyn, inputs, W, out_shape_dyn,
shape4d(strides, data_format=data_format), shape4d(strides, data_format=data_format),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment