Commit 9f790afc authored by Yuxin Wu

use tf.layers for Deconv implementation

parent 2b027291
@@ -42,10 +42,10 @@ class Model(GANModelDesc):
         y = tf.reshape(y, [-1, 1, 1, 10])
         l = tf.concat([l, tf.tile(y, [1, 7, 7, 1])], 3)
-        l = Deconv2D('deconv1', l, [14, 14, 64 * 2], 5, 2, nl=BNReLU)
+        l = Deconv2D('deconv1', l, 64 * 2, 5, 2, nl=BNReLU)
         l = tf.concat([l, tf.tile(y, [1, 14, 14, 1])], 3)
-        l = Deconv2D('deconv2', l, [28, 28, 1], 5, 2, nl=tf.identity)
+        l = Deconv2D('deconv2', l, 1, 5, 2, nl=tf.identity)
         l = tf.nn.tanh(l, name='gen')
         return l
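The call-site migration is mechanical: pass only the number of output channels instead of an explicit (h, w, channel) tuple, since (h, w) is now derived as input size * stride. For the first layer above:

    # Before: output shape spelled out explicitly.
    l = Deconv2D('deconv1', l, [14, 14, 64 * 2], 5, 2, nl=BNReLU)
    # After: channels only; (h, w) = (7 * 2, 7 * 2) follows from stride=2.
    l = Deconv2D('deconv1', l, 64 * 2, 5, 2, nl=BNReLU)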
@@ -51,10 +51,10 @@ class Model(GANModelDesc):
         l = tf.reshape(l, [-1, 4, 4, nf * 8])
         l = BNReLU(l)
         with argscope(Deconv2D, nl=BNReLU, kernel_shape=4, stride=2):
-            l = Deconv2D('deconv1', l, [8, 8, nf * 4])
-            l = Deconv2D('deconv2', l, [16, 16, nf * 2])
-            l = Deconv2D('deconv3', l, [32, 32, nf])
-            l = Deconv2D('deconv4', l, [64, 64, 3], nl=tf.identity)
+            l = Deconv2D('deconv1', l, nf * 4)
+            l = Deconv2D('deconv2', l, nf * 2)
+            l = Deconv2D('deconv3', l, nf)
+            l = Deconv2D('deconv4', l, 3, nl=tf.identity)
         l = tf.tanh(l, name='gen')
         return l
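Here the argscope supplies kernel_shape=4, stride=2 and nl=BNReLU to every Deconv2D call inside the block, so each layer doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64) and only the channel count needs to be spelled out. A sketch of what the first call inside the block expands to:

    # Equivalent to the first call under the argscope defaults:
    l = Deconv2D('deconv1', l, nf * 4, kernel_shape=4, stride=2, nl=BNReLU)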
@@ -52,8 +52,8 @@ class Model(GANModelDesc):
         l = FullyConnected('fc0', z, 1024, nl=BNReLU)
         l = FullyConnected('fc1', l, 128 * 7 * 7, nl=BNReLU)
         l = tf.reshape(l, [-1, 7, 7, 128])
-        l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU)
-        l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity)
+        l = Deconv2D('deconv1', l, 64, 4, 2, nl=BNReLU)
+        l = Deconv2D('deconv2', l, 1, 4, 2, nl=tf.identity)
         l = tf.sigmoid(l, name='gen')
         return l
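All three example generators rely on the same identity: with SAME padding, a transposed convolution produces output spatial size = input size * stride, independent of kernel size, so 7 -> 14 -> 28 here. A standalone sketch checking this with raw TF 1.x ops (the variable names are illustrative, not part of the commit):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 7, 7, 128])
    # Filter layout for conv2d_transpose is (kh, kw, out_channel, in_channel).
    w = tf.get_variable('w_demo', [4, 4, 64, 128])
    y = tf.nn.conv2d_transpose(
        x, w, output_shape=tf.stack([tf.shape(x)[0], 14, 14, 64]),
        strides=[1, 2, 2, 1], padding='SAME')
    # SAME padding: 7 * 2 = 14, with no dependence on the 4x4 kernel size.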
@@ -4,9 +4,9 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf

-from .common import layer_register, VariableHolder
+from .common import layer_register, VariableHolder, rename_get_variable
 from ..utils.argtools import shape2d, shape4d
-from .shape_utils import StaticDynamicAxis
+from ..utils.develop import log_deprecated

 __all__ = ['Conv2D', 'Deconv2D']
@@ -80,7 +80,7 @@ def Conv2D(x, out_channel, kernel_shape,

 @layer_register(log_shape=True)
-def Deconv2D(x, out_shape, kernel_shape,
+def Deconv2D(x, out_channel, kernel_shape,
              stride, padding='SAME',
              W_init=None, b_init=None,
              nl=tf.identity, use_bias=True,
@@ -91,8 +91,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     Args:
         x (tf.Tensor): a tensor of shape NHWC.
             Must have known number of channels, but can have other unknown dimensions.
-        out_shape: (h, w, channel) tuple, or just a integer channel,
-            then (h, w) will be calculated by input_shape * stride
+        out_channel: number of output channels.
         kernel_shape: (h, w) tuple or a int.
         stride: (h, w) tuple or a int.
         padding (str): 'valid' or 'same'. Case insensitive.
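Both call styles remain valid during the deprecation window. A sketch of the two forms (layer names and sizes here are illustrative):

    # Preferred: channel count only.
    l = Deconv2D('up1', l, 64, 4, 2)
    # Deprecated until 2017-11-18: a full (h, w, c) tuple still works, but now
    # warns via log_deprecated and is asserted to equal input (h, w) * stride.
    l = Deconv2D('up2', l, [28, 28, 1], 4, 2)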
@@ -113,47 +112,41 @@ def Deconv2D(x, out_shape, kernel_shape,
     channel_axis = 3 if data_format == 'NHWC' else 1
     in_channel = in_shape[channel_axis]
     assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
-    kernel_shape = shape2d(kernel_shape)
-    stride2d = shape2d(stride)
-    stride4d = shape4d(stride, data_format=data_format)
-    padding = padding.upper()
-    in_shape_dyn = tf.shape(x)
+    out_shape = out_channel
     if isinstance(out_shape, int):
         out_channel = out_shape
-        if data_format == 'NHWC':
-            shp3_0 = StaticDynamicAxis(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride2d[0] * x)
-            shp3_1 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[1] * x)
-            shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel]
-            shp3_static = [shp3_0.static, shp3_1.static, out_channel]
-        else:
-            shp3_0 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[0] * x)
-            shp3_1 = StaticDynamicAxis(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride2d[1] * x)
-            shp3_dyn = [out_channel, shp3_0.dynamic, shp3_1.dynamic]
-            shp3_static = [out_channel, shp3_0.static, shp3_1.static]
     else:
+        log_deprecated("Deconv2D(out_shape=[...])",
+                       "Use an integer 'out_channel' instead!", "2017-11-18")
         for k in out_shape:
             if not isinstance(k, int):
                 raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
         out_channel = out_shape[channel_axis - 1]  # out_shape doesn't have batch
-        shp3_static = shp3_dyn = out_shape

-    filter_shape = kernel_shape + [out_channel, in_channel]
     if W_init is None:
         W_init = tf.contrib.layers.xavier_initializer_conv2d()
     if b_init is None:
         b_init = tf.constant_initializer()
-    W = tf.get_variable('W', filter_shape, initializer=W_init)
-    if use_bias:
-        b = tf.get_variable('b', [out_channel], initializer=b_init)
-    out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn)
-    conv = tf.nn.conv2d_transpose(
-        x, W, out_shape_dyn, stride4d, padding=padding, data_format=data_format)
-    conv.set_shape(tf.TensorShape([None] + shp3_static))
-    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
-
-    ret.variables = VariableHolder(W=W)
+    with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
+        layer = tf.layers.Deconv2D(
+            out_channel, kernel_shape,
+            strides=stride, padding=padding,
+            data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
+            activation=lambda x: nl(x, name='output'),
+            use_bias=use_bias,
+            kernel_initializer=W_init,
+            bias_initializer=b_init,
+            trainable=True)
+        ret = layer.apply(x, scope=tf.get_variable_scope())
+
+    # Check that we only support out_shape = in_shape * stride
+    out_shape3 = ret.get_shape().as_list()[1:]
+    if not isinstance(out_shape, int):
+        assert list(out_shape) == out_shape3, "{} != {}".format(out_shape, out_shape3)
+
+    ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
-        ret.variables.b = b
+        ret.variables.b = layer.bias
     return ret
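tf.layers names the variables of a transposed-convolution layer 'kernel' and 'bias', while tensorpack's Deconv2D historically exposed 'W' and 'b'; rename_get_variable (imported from .common above) bridges the two so existing checkpoints keep their variable names. A minimal sketch of the idea, not the actual tensorpack helper:

    import contextlib
    import tensorflow as tf

    @contextlib.contextmanager
    def rename_get_variable_sketch(mapping):
        # Remap the trailing component of each requested variable name,
        # e.g. 'gen/deconv1/kernel' -> 'gen/deconv1/W'.
        def custom_getter(getter, name, *args, **kwargs):
            base = name.split('/')[-1]
            if base in mapping:
                name = name[:-len(base)] + mapping[base]
            return getter(name, *args, **kwargs)
        with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
            yield

Usage would mirror the hunk above: wrap layer.apply(...) in rename_get_variable_sketch({'kernel': 'W', 'bias': 'b'}).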
@@ -66,5 +66,8 @@ def monkeypatch_tf_layers():
     from tensorflow.python.layers.normalization import BatchNormalization
     tf.layers.BatchNormalization = BatchNormalization

+    from tensorflow.python.layers.convolutional import Deconv2D
+    tf.layers.Deconv2D = Deconv2D
+

 monkeypatch_tf_layers()
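The monkeypatch matters only on TF releases where tf.layers does not yet export this class publicly; in the TF 1.x private module tensorflow.python.layers.convolutional, Deconv2D is (to my knowledge) an alias of Conv2DTranspose. A quick post-patch sanity check:

    # Assumes the private module path above exists in the installed TF build.
    assert hasattr(tf.layers, 'Deconv2D')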