Commit 82330eb5 authored by Yuxin Wu's avatar Yuxin Wu

use tf.layers.Conv2DTranspose (#291)

parent 9f790afc
...@@ -130,7 +130,7 @@ def Deconv2D(x, out_channel, kernel_shape, ...@@ -130,7 +130,7 @@ def Deconv2D(x, out_channel, kernel_shape,
b_init = tf.constant_initializer() b_init = tf.constant_initializer()
with rename_get_variable({'kernel': 'W', 'bias': 'b'}): with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
layer = tf.layers.Deconv2D( layer = tf.layers.Conv2DTranspose(
out_channel, kernel_shape, out_channel, kernel_shape,
strides=stride, padding=padding, strides=stride, padding=padding,
data_format='channels_last' if data_format == 'NHWC' else 'channels_first', data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
......
...@@ -66,8 +66,8 @@ def monkeypatch_tf_layers(): ...@@ -66,8 +66,8 @@ def monkeypatch_tf_layers():
from tensorflow.python.layers.normalization import BatchNormalization from tensorflow.python.layers.normalization import BatchNormalization
tf.layers.BatchNormalization = BatchNormalization tf.layers.BatchNormalization = BatchNormalization
from tensorflow.python.layers.convolutional import Deconv2D from tensorflow.python.layers.convolutional import Conv2DTranspose
tf.layers.Deconv2D = Deconv2D tf.layers.Conv2DTranspose = Conv2DTranspose
monkeypatch_tf_layers() monkeypatch_tf_layers()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment