Commit fe33c833 authored by Yuxin Wu

Use tf.layers arguments in GAN examples.

parent 4744853b
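In short: the GAN examples switch from tensorpack's legacy layer-argument names to the tf.layers-style names used by the new layer implementations. nl becomes activation, kernel_shape becomes kernel_size, stride becomes strides, W_init becomes kernel_initializer, and Deconv2D is renamed to Conv2DTranspose (the old name is kept as an alias). Several explicit nl=tf.identity arguments are simply dropped, since a linear activation is now the default. A minimal before/after sketch of the pattern (the layer name and sizes here are illustrative, not taken from any one example):

    # old-style tensorpack arguments
    l = Conv2D('conv0', x, 64, kernel_shape=4, stride=2, nl=tf.nn.relu,
               W_init=tf.truncated_normal_initializer(stddev=0.02))
    # equivalent tf.layers-style arguments
    l = Conv2D('conv0', x, 64, kernel_size=4, strides=2, activation=tf.nn.relu,
               kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))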
@@ -31,10 +31,10 @@ class Model(GANModelDesc):
     @auto_reuse_variable_scope
     def decoder(self, z):
-        l = FullyConnected('fc', z, NF * 8 * 8, nl=tf.identity)
+        l = FullyConnected('fc', z, NF * 8 * 8)
         l = tf.reshape(l, [-1, 8, 8, NF])
-        with argscope(Conv2D, nl=tf.nn.elu, kernel_shape=3, stride=1):
+        with argscope(Conv2D, activation=tf.nn.elu, kernel_size=3, strides=1):
             l = (LinearWrap(l)
                  .Conv2D('conv1.1', NF)
                  .Conv2D('conv1.2', NF)
@@ -47,12 +47,12 @@ class Model(GANModelDesc):
                  .tf.image.resize_nearest_neighbor([64, 64], align_corners=True)
                  .Conv2D('conv4.1', NF)
                  .Conv2D('conv4.2', NF)
-                 .Conv2D('conv4.3', 3, nl=tf.identity)())
+                 .Conv2D('conv4.3', 3, activation=tf.identity)())
         return l

     @auto_reuse_variable_scope
     def encoder(self, imgs):
-        with argscope(Conv2D, nl=tf.nn.elu, kernel_shape=3, stride=1):
+        with argscope(Conv2D, activation=tf.nn.elu, kernel_size=3, strides=1):
             l = (LinearWrap(imgs)
                  .Conv2D('conv1.1', NF)
                  .Conv2D('conv1.2', NF)
@@ -70,7 +70,7 @@ class Model(GANModelDesc):
                  .Conv2D('conv4.1', NF * 4)
                  .Conv2D('conv4.2', NF * 4)
-                 .FullyConnected('fc', NH, nl=tf.identity)())
+                 .FullyConnected('fc', NH)())
         return l

     def _build_graph(self, inputs):
@@ -86,7 +86,7 @@ class Model(GANModelDesc):
             tf.summary.image(name, tf.cast(x, tf.uint8), max_outputs=30)

         with argscope([Conv2D, FullyConnected],
-                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
+                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 image_gen = self.decoder(z)
...
@@ -45,16 +45,16 @@ class Model(GANModelDesc):
                 InputDesc(tf.int32, (None,), 'label')]

     def generator(self, z, y):
-        l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, nl=BNReLU)
-        l = FullyConnected('fc1', tf.concat([l, y], 1), 64 * 2 * 7 * 7, nl=BNReLU)
+        l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, activation=BNReLU)
+        l = FullyConnected('fc1', tf.concat([l, y], 1), 64 * 2 * 7 * 7, activation=BNReLU)
         l = tf.reshape(l, [-1, 7, 7, 64 * 2])
         y = tf.reshape(y, [-1, 1, 1, 10])
         l = tf.concat([l, tf.tile(y, [1, 7, 7, 1])], 3)
-        l = Deconv2D('deconv1', l, 64 * 2, 5, 2, nl=BNReLU)
+        l = Conv2DTranspose('deconv1', l, 64 * 2, 5, 2, activation=BNReLU)
         l = tf.concat([l, tf.tile(y, [1, 14, 14, 1])], 3)
-        l = Deconv2D('deconv2', l, 1, 5, 2, nl=tf.identity)
+        l = Conv2DTranspose('deconv2', l, 1, 5, 2, activation=tf.identity)
         l = tf.nn.tanh(l, name='gen')
         return l
@@ -63,7 +63,7 @@ class Model(GANModelDesc):
         """ return a (b, 1) logits"""
         yv = y
         y = tf.reshape(y, [-1, 1, 1, 10])
-        with argscope(Conv2D, nl=tf.identity, kernel_shape=5, stride=2):
+        with argscope(Conv2D, kernel_size=5, strides=2):
             l = (LinearWrap(imgs)
                  .ConcatWith(tf.tile(y, [1, 28, 28, 1]), 3)
                  .Conv2D('conv0', 11)
@@ -76,12 +76,12 @@ class Model(GANModelDesc):
                  .apply(batch_flatten)
                  .ConcatWith(yv, 1)
-                 .FullyConnected('fc1', 1024, nl=tf.identity)
+                 .FullyConnected('fc1', 1024, activation=tf.identity)
                  .BatchNorm('bn2')
                  .tf.nn.leaky_relu()
                  .ConcatWith(yv, 1)
-                 .FullyConnected('fct', 1, nl=tf.identity)())
+                 .FullyConnected('fct', 1, activation=tf.identity)())
         return l

     def _build_graph(self, inputs):
@@ -92,8 +92,8 @@ class Model(GANModelDesc):
         z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train')
         z = tf.placeholder_with_default(z, [None, 100], name='z')  # clear the static shape

-        with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
+        with argscope([Conv2D, Conv2DTranspose, FullyConnected],
+                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 image_gen = self.generator(z, y)
             tf.summary.image('gen', image_gen, 30)
...
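The z-handling in _build_graph above (and in the DCGAN-style examples below) relies on tf.placeholder_with_default, as the "# clear the static shape" comment hints. A standalone sketch of the idea, with BATCH and ZDIM as illustrative constants:

    import tensorflow as tf

    BATCH, ZDIM = 128, 100  # illustrative values
    # During training, z defaults to a freshly sampled random tensor ...
    z = tf.random_uniform([BATCH, ZDIM], -1, 1, name='z_train')
    # ... but the node is also a placeholder, so a hand-picked z can be fed at
    # inference; re-declaring the shape as [None, ZDIM] clears the static
    # batch size, allowing any number of samples to be fed.
    z = tf.placeholder_with_default(z, [None, ZDIM], name='z')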
@@ -51,38 +51,38 @@ class Model(GANModelDesc):
             input = x
         return (LinearWrap(x)
                 .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
-                .Conv2D('conv0', chan, padding='VALID')
+                .Conv2D('conv0', chan, 3, padding='VALID')
                 .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
-                .Conv2D('conv1', chan, padding='VALID', nl=tf.identity)
+                .Conv2D('conv1', chan, 3, padding='VALID', activation=tf.identity)
                 .InstanceNorm('inorm')()) + input

     @auto_reuse_variable_scope
     def generator(self, img):
         assert img is not None
-        with argscope([Conv2D, Deconv2D], nl=INReLU, kernel_shape=3):
+        with argscope([Conv2D, Conv2DTranspose], activation=INReLU):
             l = (LinearWrap(img)
                  .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
-                 .Conv2D('conv0', NF, kernel_shape=7, padding='VALID')
-                 .Conv2D('conv1', NF * 2, stride=2)
-                 .Conv2D('conv2', NF * 4, stride=2)())
+                 .Conv2D('conv0', NF, 7, padding='VALID')
+                 .Conv2D('conv1', NF * 2, 3, strides=2)
+                 .Conv2D('conv2', NF * 4, 3, strides=2)())
         for k in range(9):
             l = Model.build_res_block(l, 'res{}'.format(k), NF * 4, first=(k == 0))
         l = (LinearWrap(l)
-             .Deconv2D('deconv0', NF * 2, stride=2)
-             .Deconv2D('deconv1', NF * 1, stride=2)
+             .Conv2DTranspose('deconv0', NF * 2, 3, strides=2)
+             .Conv2DTranspose('deconv1', NF * 1, 3, strides=2)
              .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
-             .Conv2D('convlast', 3, kernel_shape=7, padding='VALID', nl=tf.tanh, use_bias=True)())
+             .Conv2D('convlast', 3, 7, padding='VALID', activation=tf.tanh, use_bias=True)())
         return l

     @auto_reuse_variable_scope
     def discriminator(self, img):
-        with argscope(Conv2D, nl=INLReLU, kernel_shape=4, stride=2):
+        with argscope(Conv2D, activation=INLReLU, kernel_size=4, strides=2):
             l = (LinearWrap(img)
-                 .Conv2D('conv0', NF, nl=tf.nn.leaky_relu)
+                 .Conv2D('conv0', NF, activation=tf.nn.leaky_relu)
                  .Conv2D('conv1', NF * 2)
                  .Conv2D('conv2', NF * 4)
-                 .Conv2D('conv3', NF * 8, stride=1)
-                 .Conv2D('conv4', 1, stride=1, nl=tf.identity, use_bias=True)())
+                 .Conv2D('conv3', NF * 8, strides=1)
+                 .Conv2D('conv4', 1, strides=1, activation=tf.identity, use_bias=True)())
         return l

     def _build_graph(self, inputs):
@@ -101,9 +101,9 @@ class Model(GANModelDesc):
             tf.summary.image(name, im, max_outputs=50)

         # use the initializers from torch
-        with argscope([Conv2D, Deconv2D], use_bias=False,
-                      W_init=tf.random_normal_initializer(stddev=0.02)), \
-                argscope([Conv2D, Deconv2D, InstanceNorm], data_format='NCHW'):
+        with argscope([Conv2D, Conv2DTranspose], use_bias=False,
+                      kernel_initializer=tf.random_normal_initializer(stddev=0.02)), \
+                argscope([Conv2D, Conv2DTranspose, InstanceNorm], data_format='channels_first'):
             with tf.variable_scope('gen'):
                 with tf.variable_scope('B'):
                     AB = self.generator(A)
@@ -211,10 +211,11 @@ if __name__ == '__main__':
     logger.auto_set_dir()

-    data = get_data(args.data)
-    data = PrintData(data)
-    GANTrainer(QueueInput(data), Model()).train_with_defaults(
+    df = get_data(args.data)
+    df = PrintData(df)
+    data = StagingInput(QueueInput(df))
+    GANTrainer(data, Model()).train_with_defaults(
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter(
...
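The __main__ change above does more than rename data to df: the DataFlow is now wrapped in QueueInput and then in StagingInput before being handed to GANTrainer. A sketch of the resulting pipeline; the comment about GPU prefetching is my reading of why StagingInput is added, not something stated in the commit:

    df = get_data(args.data)             # a DataFlow yielding training batches
    df = PrintData(df)                   # log a few batches for sanity checking
    # QueueInput moves batches through a TF queue; StagingInput additionally
    # prefetches them via a staging area (presumably to hide host-to-device
    # copy latency).
    data = StagingInput(QueueInput(df))
    GANTrainer(data, Model()).train_with_defaults(callbacks=[ModelSaver()])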
@@ -46,14 +46,14 @@ class Model(GANModelDesc):
     def generator(self, z):
         """ return an image generated from z"""
         nf = 64
-        l = FullyConnected('fc0', z, nf * 8 * 4 * 4, nl=tf.identity)
+        l = FullyConnected('fc0', z, nf * 8 * 4 * 4, activation=tf.identity)
         l = tf.reshape(l, [-1, 4, 4, nf * 8])
         l = BNReLU(l)
-        with argscope(Deconv2D, nl=BNReLU, kernel_shape=4, stride=2):
-            l = Deconv2D('deconv1', l, nf * 4)
-            l = Deconv2D('deconv2', l, nf * 2)
-            l = Deconv2D('deconv3', l, nf)
-            l = Deconv2D('deconv4', l, 3, nl=tf.identity)
+        with argscope(Conv2DTranspose, activation=BNReLU, kernel_size=4, strides=2):
+            l = Conv2DTranspose('deconv1', l, nf * 4)
+            l = Conv2DTranspose('deconv2', l, nf * 2)
+            l = Conv2DTranspose('deconv3', l, nf)
+            l = Conv2DTranspose('deconv4', l, 3, activation=tf.identity)
         l = tf.tanh(l, name='gen')
         return l
@@ -61,9 +61,9 @@ class Model(GANModelDesc):
     def discriminator(self, imgs):
         """ return a (b, 1) logits"""
         nf = 64
-        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2):
+        with argscope(Conv2D, kernel_size=4, strides=2):
             l = (LinearWrap(imgs)
-                 .Conv2D('conv0', nf, nl=tf.nn.leaky_relu)
+                 .Conv2D('conv0', nf, activation=tf.nn.leaky_relu)
                  .Conv2D('conv1', nf * 2)
                  .BatchNorm('bn1')
                  .tf.nn.leaky_relu()
@@ -73,7 +73,7 @@ class Model(GANModelDesc):
                  .Conv2D('conv3', nf * 8)
                  .BatchNorm('bn3')
                  .tf.nn.leaky_relu()
-                 .FullyConnected('fct', 1, nl=tf.identity)())
+                 .FullyConnected('fct', 1)())
         return l

     def _build_graph(self, inputs):
@@ -83,8 +83,8 @@ class Model(GANModelDesc):
         z = tf.random_uniform([self.batch, self.zdim], -1, 1, name='z_train')
         z = tf.placeholder_with_default(z, [None, self.zdim], name='z')

-        with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
+        with argscope([Conv2D, Conv2DTranspose, FullyConnected],
+                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 image_gen = self.generator(z)
             tf.summary.image('generated-samples', image_gen, max_outputs=30)
...
@@ -21,13 +21,6 @@ from GAN import SeparateGANTrainer, GANModelDesc
 3. Start training gender transfer:
     ./DiscoGAN-CelebA.py --data /path/to/img_align_celeba --style-A Male
 4. Visualize the gender conversion images in tensorboard.
-
-With TF1.0.1, cuda 8.0, cudnn 5.1.10,
-the training on 64x64 images of batch 64 runs 5.4 it/s on Tesla M40.
-This is 2.4x as fast as the original PyTorch implementation.
-The cause is probably that in the torch implementation,
-a backward() computes gradients for ALL parameters, which is not necessary in GAN.
-
 """

 SHAPE = 64
@@ -48,29 +41,29 @@ class Model(GANModelDesc):
     @auto_reuse_variable_scope
     def generator(self, img):
         assert img is not None
-        with argscope([Conv2D, Deconv2D],
-                      nl=BNLReLU, kernel_shape=4, stride=2), \
-                argscope(Deconv2D, nl=BNReLU):
+        with argscope([Conv2D, Conv2DTranspose],
+                      activation=BNLReLU, kernel_size=4, strides=2), \
+                argscope(Conv2DTranspose, activation=BNReLU):
             l = (LinearWrap(img)
-                 .Conv2D('conv0', NF, nl=tf.nn.leaky_relu)
+                 .Conv2D('conv0', NF, activation=tf.nn.leaky_relu)
                  .Conv2D('conv1', NF * 2)
                  .Conv2D('conv2', NF * 4)
                  .Conv2D('conv3', NF * 8)
-                 .Deconv2D('deconv0', NF * 4)
-                 .Deconv2D('deconv1', NF * 2)
-                 .Deconv2D('deconv2', NF * 1)
-                 .Deconv2D('deconv3', 3, nl=tf.identity)
+                 .Conv2DTranspose('deconv0', NF * 4)
+                 .Conv2DTranspose('deconv1', NF * 2)
+                 .Conv2DTranspose('deconv2', NF * 1)
+                 .Conv2DTranspose('deconv3', 3, activation=tf.identity)
                  .tf.sigmoid()())
         return l

     @auto_reuse_variable_scope
     def discriminator(self, img):
-        with argscope(Conv2D, nl=BNLReLU, kernel_shape=4, stride=2):
-            l = Conv2D('conv0', img, NF, nl=tf.nn.leaky_relu)
+        with argscope(Conv2D, activation=BNLReLU, kernel_size=4, strides=2):
+            l = Conv2D('conv0', img, NF, activation=tf.nn.leaky_relu)
             relu1 = Conv2D('conv1', l, NF * 2)
             relu2 = Conv2D('conv2', relu1, NF * 4)
             relu3 = Conv2D('conv3', relu2, NF * 8)
-            logits = FullyConnected('fc', relu3, 1, nl=tf.identity)
+            logits = FullyConnected('fc', relu3, 1, activation=tf.identity)
         return logits, [relu1, relu2, relu3]

     def get_feature_match_loss(self, feats_real, feats_fake):
@@ -91,11 +84,11 @@ class Model(GANModelDesc):
         B = tf.transpose(B / 255.0, [0, 3, 1, 2])

         # use the torch initializers
-        with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.variance_scaling_initializer(scale=0.333, distribution='uniform'),
+        with argscope([Conv2D, Conv2DTranspose, FullyConnected],
+                      kernel_initializer=tf.variance_scaling_initializer(scale=0.333, distribution='uniform'),
                       use_bias=False), \
                 argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
-                argscope([Conv2D, Deconv2D, BatchNorm], data_format='NCHW'):
+                argscope([Conv2D, Conv2DTranspose, BatchNorm], data_format='NCHW'):
             with tf.variable_scope('gen'):
                 with tf.variable_scope('B'):
                     AB = self.generator(A)
@@ -194,7 +187,7 @@ def get_celebA_data(datadir, styleA, styleB=None):
             imgaug.Resize(64)]
     df = AugmentImageComponents(df, augs, (0, 1))
     df = BatchData(df, BATCH)
-    df = PrefetchDataZMQ(df, 1)
+    df = PrefetchDataZMQ(df, 3)
     return df
...
@@ -74,54 +74,54 @@ class Model(GANModelDesc):
         with argscope(BatchNorm, use_local_stat=True), \
                 argscope(Dropout, is_training=True):
             # always use local stat for BN, and apply dropout even in testing
-            with argscope(Conv2D, kernel_shape=4, stride=2, nl=BNLReLU):
-                e1 = Conv2D('conv1', imgs, NF, nl=tf.nn.leaky_relu)
+            with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
+                e1 = Conv2D('conv1', imgs, NF, activation=tf.nn.leaky_relu)
                 e2 = Conv2D('conv2', e1, NF * 2)
                 e3 = Conv2D('conv3', e2, NF * 4)
                 e4 = Conv2D('conv4', e3, NF * 8)
                 e5 = Conv2D('conv5', e4, NF * 8)
                 e6 = Conv2D('conv6', e5, NF * 8)
                 e7 = Conv2D('conv7', e6, NF * 8)
-                e8 = Conv2D('conv8', e7, NF * 8, nl=BNReLU)  # 1x1
-            with argscope(Deconv2D, nl=BNReLU, kernel_shape=4, stride=2):
+                e8 = Conv2D('conv8', e7, NF * 8, activation=BNReLU)  # 1x1
+            with argscope(Conv2DTranspose, activation=BNReLU, kernel_size=4, strides=2):
                 return (LinearWrap(e8)
-                        .Deconv2D('deconv1', NF * 8)
+                        .Conv2DTranspose('deconv1', NF * 8)
                         .Dropout()
                         .ConcatWith(e7, 3)
-                        .Deconv2D('deconv2', NF * 8)
+                        .Conv2DTranspose('deconv2', NF * 8)
                         .Dropout()
                         .ConcatWith(e6, 3)
-                        .Deconv2D('deconv3', NF * 8)
+                        .Conv2DTranspose('deconv3', NF * 8)
                         .Dropout()
                         .ConcatWith(e5, 3)
-                        .Deconv2D('deconv4', NF * 8)
+                        .Conv2DTranspose('deconv4', NF * 8)
                         .ConcatWith(e4, 3)
-                        .Deconv2D('deconv5', NF * 4)
+                        .Conv2DTranspose('deconv5', NF * 4)
                         .ConcatWith(e3, 3)
-                        .Deconv2D('deconv6', NF * 2)
+                        .Conv2DTranspose('deconv6', NF * 2)
                         .ConcatWith(e2, 3)
-                        .Deconv2D('deconv7', NF * 1)
+                        .Conv2DTranspose('deconv7', NF * 1)
                         .ConcatWith(e1, 3)
-                        .Deconv2D('deconv8', OUT_CH, nl=tf.tanh)())
+                        .Conv2DTranspose('deconv8', OUT_CH, activation=tf.tanh)())

     @auto_reuse_variable_scope
     def discriminator(self, inputs, outputs):
         """ return a (b, 1) logits"""
         l = tf.concat([inputs, outputs], 3)
-        with argscope(Conv2D, kernel_shape=4, stride=2, nl=BNLReLU):
+        with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
             l = (LinearWrap(l)
-                 .Conv2D('conv0', NF, nl=tf.nn.leaky_relu)
+                 .Conv2D('conv0', NF, activation=tf.nn.leaky_relu)
                  .Conv2D('conv1', NF * 2)
                  .Conv2D('conv2', NF * 4)
-                 .Conv2D('conv3', NF * 8, stride=1, padding='VALID')
-                 .Conv2D('convlast', 1, stride=1, padding='VALID', nl=tf.identity)())
+                 .Conv2D('conv3', NF * 8, strides=1, padding='VALID')
+                 .Conv2D('convlast', 1, strides=1, padding='VALID', activation=tf.identity)())
         return l

     def _build_graph(self, inputs):
         input, output = inputs
         input, output = input / 128.0 - 1, output / 128.0 - 1

-        with argscope([Conv2D, Deconv2D], W_init=tf.truncated_normal_initializer(stddev=0.02)):
+        with argscope([Conv2D, Conv2DTranspose], kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 fake_output = self.generator(input)
             with tf.variable_scope('discrim'):
...
@@ -26,9 +26,9 @@ class Model(DCGAN.Model):
     @auto_reuse_variable_scope
     def discriminator(self, imgs):
         nf = 64
-        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2):
+        with argscope(Conv2D, activation=tf.identity, kernel_size=4, strides=2):
             l = (LinearWrap(imgs)
-                 .Conv2D('conv0', nf, nl=tf.nn.leaky_relu)
+                 .Conv2D('conv0', nf, activation=tf.nn.leaky_relu)
                  .Conv2D('conv1', nf * 2)
                  .LayerNorm('ln1')
                  .tf.nn.leaky_relu()
@@ -38,7 +38,7 @@ class Model(DCGAN.Model):
                  .Conv2D('conv3', nf * 8)
                  .LayerNorm('ln3')
                  .tf.nn.leaky_relu()
-                 .FullyConnected('fct', 1, nl=tf.identity)())
+                 .FullyConnected('fct', 1, activation=tf.identity)())
         return tf.reshape(l, [-1])

     def _build_graph(self, inputs):
@@ -48,8 +48,8 @@ class Model(DCGAN.Model):
         z = tf.random_normal([self.batch, self.zdim], name='z_train')
         z = tf.placeholder_with_default(z, [None, self.zdim], name='z')

-        with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
+        with argscope([Conv2D, Conv2DTranspose, FullyConnected],
+                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 image_gen = self.generator(z)
             tf.summary.image('generated-samples', image_gen, max_outputs=30)
...
@@ -109,33 +109,33 @@ class Model(GANModelDesc):
         return [InputDesc(tf.float32, (None, 28, 28), 'input')]

     def generator(self, z):
-        l = FullyConnected('fc0', z, 1024, nl=BNReLU)
-        l = FullyConnected('fc1', l, 128 * 7 * 7, nl=BNReLU)
+        l = FullyConnected('fc0', z, 1024, activation=BNReLU)
+        l = FullyConnected('fc1', l, 128 * 7 * 7, activation=BNReLU)
         l = tf.reshape(l, [-1, 7, 7, 128])
-        l = Deconv2D('deconv1', l, 64, 4, 2, nl=BNReLU)
-        l = Deconv2D('deconv2', l, 1, 4, 2, nl=tf.identity)
+        l = Conv2DTranspose('deconv1', l, 64, 4, 2, activation=BNReLU)
+        l = Conv2DTranspose('deconv2', l, 1, 4, 2, activation=tf.identity)
         l = tf.sigmoid(l, name='gen')
         return l

     @auto_reuse_variable_scope
     def discriminator(self, imgs):
-        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2):
+        with argscope(Conv2D, kernel_size=4, strides=2):
             l = (LinearWrap(imgs)
                  .Conv2D('conv0', 64)
                  .tf.nn.leaky_relu()
                  .Conv2D('conv1', 128)
                  .BatchNorm('bn1')
                  .tf.nn.leaky_relu()
-                 .FullyConnected('fc1', 1024, nl=tf.identity)
+                 .FullyConnected('fc1', 1024)
                  .BatchNorm('bn2')
                  .tf.nn.leaky_relu()())

-        logits = FullyConnected('fct', l, 1, nl=tf.identity)
+        logits = FullyConnected('fct', l, 1)
         encoder = (LinearWrap(l)
-                   .FullyConnected('fce1', 128, nl=tf.identity)
+                   .FullyConnected('fce1', 128)
                    .BatchNorm('bne')
                    .tf.nn.leaky_relu()
-                   .FullyConnected('fce-out', DIST_PARAM_DIM, nl=tf.identity)())
+                   .FullyConnected('fce-out', DIST_PARAM_DIM)())
         return logits, encoder

     def _build_graph(self, inputs):
@@ -148,8 +148,8 @@ class Model(GANModelDesc):
             tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
         z = tf.concat([zc, z_noise], 1, name='z')

-        with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
+        with argscope([Conv2D, Conv2DTranspose, FullyConnected],
+                      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
             with tf.variable_scope('gen'):
                 fake_sample = self.generator(z)
                 fake_sample_viz = tf.cast((fake_sample) * 255.0, tf.uint8, name='viz')
...
@@ -9,7 +9,7 @@ from ..tfutils.common import get_tf_version_number
 from ..utils.argtools import shape2d, shape4d, get_data_format
 from .tflayer import rename_get_variable, convert_to_tflayer_args

-__all__ = ['Conv2D', 'Deconv2D']
+__all__ = ['Conv2D', 'Deconv2D', 'Conv2DTranspose']


 @layer_register(log_shape=True)
@@ -125,7 +125,7 @@ def Conv2D(
         'kernel_shape': 'kernel_size',
         'stride': 'strides',
     })
-def Deconv2D(
+def Conv2DTranspose(
         inputs,
         filters,
         kernel_size,
@@ -172,3 +172,6 @@ def Deconv2D(
         if use_bias:
             ret.variables.b = layer.bias
     return tf.identity(ret, name='output')
+
+
+Deconv2D = Conv2DTranspose
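The convert_to_tflayer_args decorator visible above is what keeps old call sites working: it remaps legacy argument names ('kernel_shape' to 'kernel_size', 'stride' to 'strides') to the tf.layers spellings before the layer body runs, and the Deconv2D = Conv2DTranspose assignment keeps the old layer name importable. A simplified standalone sketch of the remapping idea (not tensorpack's actual implementation, which handles more cases):

    import functools

    def remap_kwargs(name_mapping):
        """Rename legacy keyword arguments to their tf.layers equivalents."""
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for old, new in name_mapping.items():
                    if old in kwargs:  # assume callers use only one spelling
                        kwargs[new] = kwargs.pop(old)
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @remap_kwargs({'kernel_shape': 'kernel_size', 'stride': 'strides'})
    def Conv2DTranspose(inputs, filters, kernel_size, strides=1):
        pass  # build the layer here

    Deconv2D = Conv2DTranspose  # legacy name kept as an alias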
@@ -28,7 +28,7 @@ def FullyConnected(
         activity_regularizer=None):
     """
     A wrapper around `tf.layers.Dense`.
-    One differences to maintain backward-compatibility:
+    One difference to maintain backward-compatibility:
     Default weight initializer is variance_scaling_initializer(2.0).

     Variable Names:
...
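The docstring hunk above points at the one intentional difference from tf.layers.Dense: tensorpack's FullyConnected defaults its weight initializer to variance_scaling_initializer(2.0), a He-style initializer, rather than the tf.layers default. If the plain tf.layers behavior is wanted, the initializer can be passed explicitly; a sketch, assuming glorot uniform is indeed the tf.layers default:

    l = FullyConnected('fc', x, 1024,
                       kernel_initializer=tf.glorot_uniform_initializer())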