Commit 9f790afc authored by Yuxin Wu

use tf.layers for Deconv implementation

parent 2b027291
@@ -42,10 +42,10 @@ class Model(GANModelDesc):
         y = tf.reshape(y, [-1, 1, 1, 10])
         l = tf.concat([l, tf.tile(y, [1, 7, 7, 1])], 3)
-        l = Deconv2D('deconv1', l, [14, 14, 64 * 2], 5, 2, nl=BNReLU)
+        l = Deconv2D('deconv1', l, 64 * 2, 5, 2, nl=BNReLU)
         l = tf.concat([l, tf.tile(y, [1, 14, 14, 1])], 3)
-        l = Deconv2D('deconv2', l, [28, 28, 1], 5, 2, nl=tf.identity)
+        l = Deconv2D('deconv2', l, 1, 5, 2, nl=tf.identity)
         l = tf.nn.tanh(l, name='gen')
         return l
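The same substitution appears in all three example scripts touched by this commit: the explicit (h, w) part of out_shape is dropped, because a SAME-padded transposed convolution always produces output size = input size * stride. A minimal stand-alone check of that inference, using the public tf.layers.conv2d_transpose wrapper rather than tensorpack's layer (shapes here are recomputed for illustration, not stated in the diff):

    import tensorflow as tf

    # 7x7 input, stride 2, SAME padding: the spatial size doubles,
    # so the old [14, 14, ...] never needed to be spelled out.
    x = tf.placeholder(tf.float32, [None, 7, 7, 128])
    y = tf.layers.conv2d_transpose(x, 64 * 2, 5, strides=2, padding='same')
    print(y.shape)   # (?, 14, 14, 128)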
@@ -51,10 +51,10 @@ class Model(GANModelDesc):
         l = tf.reshape(l, [-1, 4, 4, nf * 8])
         l = BNReLU(l)
         with argscope(Deconv2D, nl=BNReLU, kernel_shape=4, stride=2):
-            l = Deconv2D('deconv1', l, [8, 8, nf * 4])
-            l = Deconv2D('deconv2', l, [16, 16, nf * 2])
-            l = Deconv2D('deconv3', l, [32, 32, nf])
-            l = Deconv2D('deconv4', l, [64, 64, 3], nl=tf.identity)
+            l = Deconv2D('deconv1', l, nf * 4)
+            l = Deconv2D('deconv2', l, nf * 2)
+            l = Deconv2D('deconv3', l, nf)
+            l = Deconv2D('deconv4', l, 3, nl=tf.identity)
         l = tf.tanh(l, name='gen')
         return l
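For readers unfamiliar with tensorpack's argscope: it supplies default keyword arguments to every Deconv2D call inside the block, so the hunk above is roughly equivalent to writing each call in full (a sketch of the expansion, not part of the diff):

    l = Deconv2D('deconv1', l, nf * 4, 4, 2, nl=BNReLU)
    l = Deconv2D('deconv2', l, nf * 2, 4, 2, nl=BNReLU)
    l = Deconv2D('deconv3', l, nf, 4, 2, nl=BNReLU)
    l = Deconv2D('deconv4', l, 3, 4, 2, nl=tf.identity)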
@@ -52,8 +52,8 @@ class Model(GANModelDesc):
         l = FullyConnected('fc0', z, 1024, nl=BNReLU)
         l = FullyConnected('fc1', l, 128 * 7 * 7, nl=BNReLU)
         l = tf.reshape(l, [-1, 7, 7, 128])
-        l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU)
-        l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity)
+        l = Deconv2D('deconv1', l, 64, 4, 2, nl=BNReLU)
+        l = Deconv2D('deconv2', l, 1, 4, 2, nl=tf.identity)
         l = tf.sigmoid(l, name='gen')
         return l
@@ -4,9 +4,9 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-from .common import layer_register, VariableHolder
+from .common import layer_register, VariableHolder, rename_get_variable
 from ..utils.argtools import shape2d, shape4d
-from .shape_utils import StaticDynamicAxis
+from ..utils.develop import log_deprecated
 __all__ = ['Conv2D', 'Deconv2D']
@@ -80,7 +80,7 @@ def Conv2D(x, out_channel, kernel_shape,
 @layer_register(log_shape=True)
-def Deconv2D(x, out_shape, kernel_shape,
+def Deconv2D(x, out_channel, kernel_shape,
              stride, padding='SAME',
              W_init=None, b_init=None,
              nl=tf.identity, use_bias=True,
@@ -91,8 +91,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     Args:
         x (tf.Tensor): a tensor of shape NHWC.
             Must have a known number of channels, but may have other unknown dimensions.
-        out_shape: (h, w, channel) tuple, or just an integer channel;
-            (h, w) will then be calculated as input_shape * stride.
+        out_channel: the number of output channels.
         kernel_shape: (h, w) tuple or an int.
         stride: (h, w) tuple or an int.
         padding (str): 'valid' or 'same'. Case insensitive.
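Per the implementation below, both call forms keep working during the deprecation window: the old list form logs a warning and is asserted against the inferred shape. A sketch of the two forms, assuming an NHWC input x of shape [N, 14, 14, 64]; the layer name 'up' and the shapes are illustrative, not from the diff:

    # New form: pass only the channel count.
    out = Deconv2D('up', x, 32, 4, 2)            # -> [N, 28, 28, 32]

    # Deprecated form: accepted until 2017-11-18, logs a warning, and
    # must satisfy (h, w) == input (h, w) * stride.
    out = Deconv2D('up', x, [28, 28, 32], 4, 2)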
@@ -113,47 +112,41 @@ def Deconv2D(x, out_shape, kernel_shape,
     channel_axis = 3 if data_format == 'NHWC' else 1
     in_channel = in_shape[channel_axis]
     assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
-    kernel_shape = shape2d(kernel_shape)
-    stride2d = shape2d(stride)
-    stride4d = shape4d(stride, data_format=data_format)
     padding = padding.upper()
-    in_shape_dyn = tf.shape(x)

+    out_shape = out_channel
     if isinstance(out_shape, int):
         out_channel = out_shape
-        if data_format == 'NHWC':
-            shp3_0 = StaticDynamicAxis(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride2d[0] * x)
-            shp3_1 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[1] * x)
-            shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel]
-            shp3_static = [shp3_0.static, shp3_1.static, out_channel]
-        else:
-            shp3_0 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[0] * x)
-            shp3_1 = StaticDynamicAxis(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride2d[1] * x)
-            shp3_dyn = [out_channel, shp3_0.dynamic, shp3_1.dynamic]
-            shp3_static = [out_channel, shp3_0.static, shp3_1.static]
     else:
+        log_deprecated("Deconv2D(out_shape=[...])",
+                       "Use an integer 'out_channel' instead!", "2017-11-18")
         for k in out_shape:
             if not isinstance(k, int):
                 raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
         out_channel = out_shape[channel_axis - 1]  # out_shape doesn't have batch
-        shp3_static = shp3_dyn = out_shape

-    filter_shape = kernel_shape + [out_channel, in_channel]
     if W_init is None:
         W_init = tf.contrib.layers.xavier_initializer_conv2d()
     if b_init is None:
         b_init = tf.constant_initializer()
-    W = tf.get_variable('W', filter_shape, initializer=W_init)
-    if use_bias:
-        b = tf.get_variable('b', [out_channel], initializer=b_init)
-    out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn)
-    conv = tf.nn.conv2d_transpose(
-        x, W, out_shape_dyn, stride4d, padding=padding, data_format=data_format)
-    conv.set_shape(tf.TensorShape([None] + shp3_static))
-    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-    ret.variables = VariableHolder(W=W)
+    with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
+        layer = tf.layers.Deconv2D(
+            out_channel, kernel_shape,
+            strides=stride, padding=padding,
+            data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
+            activation=lambda x: nl(x, name='output'),
+            use_bias=use_bias,
+            kernel_initializer=W_init,
+            bias_initializer=b_init,
+            trainable=True)
+        ret = layer.apply(x, scope=tf.get_variable_scope())

+    # Check that we only support out_shape = in_shape * stride
+    out_shape3 = ret.get_shape().as_list()[1:]
+    if not isinstance(out_shape, int):
+        assert list(out_shape) == out_shape3, "{} != {}".format(out_shape, out_shape3)

+    ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
-        ret.variables.b = b
+        ret.variables.b = layer.bias
     return ret
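rename_get_variable (imported from .common above) is what keeps checkpoint compatibility: tf.layers creates variables named 'kernel' and 'bias', while tensorpack's Deconv2D historically exposed 'W' and 'b'. Its actual implementation is not shown in this diff; a minimal sketch of the idea, assuming a variable_scope custom getter:

    from contextlib import contextmanager
    import tensorflow as tf

    @contextmanager
    def rename_get_variable_sketch(mapping):
        # Rewrite the last path component of each requested variable,
        # so tf.layers' 'kernel'/'bias' are created as 'W'/'b'.
        def getter(orig_getter, name, *args, **kwargs):
            scope, _, basename = name.rpartition('/')
            if basename in mapping:
                name = (scope + '/' if scope else '') + mapping[basename]
            return orig_getter(name, *args, **kwargs)
        with tf.variable_scope(tf.get_variable_scope(), custom_getter=getter):
            yield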
@@ -66,5 +66,8 @@ def monkeypatch_tf_layers():
     from tensorflow.python.layers.normalization import BatchNormalization
     tf.layers.BatchNormalization = BatchNormalization

+    from tensorflow.python.layers.convolutional import Deconv2D
+    tf.layers.Deconv2D = Deconv2D
+

 monkeypatch_tf_layers()
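This patch backfills tf.layers.Deconv2D on TF releases whose public tf.layers namespace does not yet re-export the class, so the Deconv2D implementation above can rely on it unconditionally. If one wanted the patch to be idempotent, a guarded variant might look like this (the hasattr check is an addition here, not in the diff):

    if not hasattr(tf.layers, 'Deconv2D'):
        from tensorflow.python.layers.convolutional import Deconv2D
        tf.layers.Deconv2D = Deconv2D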