Commit 54dc36ba authored by Yuxin Wu's avatar Yuxin Wu

improved wgan

parent 745c70a4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: Improved-WGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import SeparateGANTrainer

"""
Improved Wasserstein-GAN.
See the docstring in DCGAN.py for usage.
"""

# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN, and change the batch size & model
import DCGAN

G.BATCH = 64
G.Z_DIM = 128

class Model(DCGAN.Model):
    # replace BatchNorm by LayerNorm
    @auto_reuse_variable_scope
    def discriminator(self, imgs):
        nf = 64
        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
                argscope(LeakyReLU, alpha=0.2):
            l = (LinearWrap(imgs)
                 .Conv2D('conv0', nf, nl=LeakyReLU)
                 .Conv2D('conv1', nf * 2)
                 .LayerNorm('ln1').LeakyReLU()
                 .Conv2D('conv2', nf * 4)
                 .LayerNorm('ln2').LeakyReLU()
                 .Conv2D('conv3', nf * 8)
                 .LayerNorm('ln3').LeakyReLU()
                 .FullyConnected('fct', 1, nl=tf.identity)())
        return tf.reshape(l, [-1])
    def _build_graph(self, inputs):
        image_pos = inputs[0]
        image_pos = image_pos / 128.0 - 1

        z = tf.random_normal([G.BATCH, G.Z_DIM], name='z_train')
        z = tf.placeholder_with_default(z, [None, G.Z_DIM], name='z')

        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
            with tf.variable_scope('gen'):
                image_gen = self.generator(z)
            tf.summary.image('generated-samples', image_gen, max_outputs=30)

            # sample points uniformly along lines between real and generated images
            alpha = tf.random_uniform(shape=[G.BATCH, 1, 1, 1],
                                      minval=0., maxval=1., name='alpha')
            interp = image_pos + alpha * (image_gen - image_pos)

            with tf.variable_scope('discrim'):
                vecpos = self.discriminator(image_pos)
                vecneg = self.discriminator(image_gen)
                vec_interp = self.discriminator(interp)

        # the Wasserstein-GAN losses
        self.d_loss = tf.reduce_mean(vecneg - vecpos, name='d_loss')
        self.g_loss = tf.negative(tf.reduce_mean(vecneg), name='g_loss')

        # the gradient penalty: push the norm of the critic's gradient
        # at the interpolated points towards 1
        gradients = tf.gradients(vec_interp, [interp])[0]
        gradients = tf.sqrt(tf.reduce_sum(tf.square(gradients), [1, 2, 3]))
        gradients_rms = symbolic_functions.rms(gradients, 'gradient_rms')
        gradient_penalty = tf.reduce_mean(tf.square(gradients - 1), name='gradient_penalty')

        add_moving_summary(self.d_loss, self.g_loss, gradient_penalty, gradients_rms)
        self.d_loss = tf.add(self.d_loss, 10 * gradient_penalty)  # lambda = 10, as in the paper

        self.collect_variables()

    def _get_optimizer(self):
        lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
        opt = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
        return opt

# let DCGAN.get_config() build this Model instead of its own
DCGAN.Model = Model

if __name__ == '__main__':
    args = DCGAN.get_args()
    if args.sample:
        DCGAN.sample(args.load)
    else:
        assert args.data
        logger.auto_set_dir()
        config = DCGAN.get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        # g_period=6: G is updated once every 6 steps, i.e. 5 D updates per G update
        SeparateGANTrainer(config, g_period=6).train()
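
For reference, what _build_graph assembles is the Improved-WGAN (WGAN-GP) objective: the critic D minimizes E[D(G(z))] - E[D(x)] + 10 * E[(||grad D(x_interp)||_2 - 1)^2], where x_interp lies on a line between a real and a generated image, and G maximizes E[D(G(z))]. Below is a minimal NumPy sketch of the same arithmetic on a toy linear critic; every name in it is made up for illustration and nothing here is a tensorpack API:

import numpy as np

rng = np.random.RandomState(0)
w = rng.randn(784) * 0.01                # toy linear critic: D(x) = x . w

def critic(x):                           # x: (batch, 784) -> (batch,)
    return x.dot(w)

x_real = rng.rand(64, 784) * 2 - 1       # stand-in for real images in [-1, 1]
x_fake = rng.rand(64, 784) * 2 - 1       # stand-in for generated images

alpha = rng.rand(64, 1)                  # one interpolation weight per sample
x_interp = x_real + alpha * (x_fake - x_real)

# the penalty is on the critic's gradient at x_interp; for a linear
# critic dD/dx = w at every point, so each per-sample norm is ||w||
grad = np.broadcast_to(w, x_interp.shape)
grad_norm = np.sqrt((grad ** 2).sum(axis=1))
gradient_penalty = np.mean((grad_norm - 1) ** 2)

d_loss = critic(x_fake).mean() - critic(x_real).mean() + 10 * gradient_penalty
g_loss = -critic(x_fake).mean()
print('d_loss: %.3f  g_loss: %.3f' % (d_loss, g_loss))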
examples/GAN/README.md:

@@ -12,11 +12,13 @@ Reproduce the following GAN-related methods:
 + [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
++ Improved Wasserstein GAN ([Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028))
 + DiscoGAN ([Learning to Discover Cross-Domain Relations with Generative Adversarial Networks](https://arxiv.org/abs/1703.05192))
 
 Please see the __docstring__ in each script for detailed usage and pretrained models.
 
-## DCGAN-CelebA.py
+## DCGAN.py
 
 Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
@@ -54,9 +56,9 @@ It then maximizes mutual information between these latent variables and the imag
 
 Train a simple GAN on mnist, conditioned on the class labels.
 
-## WGAN-CelebA.py
+## WGAN.py, Improved-WGAN.py
 
-Reproduce Wasserstein GAN by some small modifications on DCGAN-CelebA.py.
+Just some small modifications on top of DCGAN.py.
 
 ## DiscoGAN-CelebA.py
...
tensorpack/callbacks/progress.py:

@@ -90,6 +90,7 @@ class ProgressBar(Callback):
         super(ProgressBar, self).__init__()
         self._names = [get_op_tensor_name(n)[1] for n in names]
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
+        self._bar = None
 
     def _before_train(self):
         self._last_updated = self.local_step
@@ -125,4 +126,5 @@ class ProgressBar(Callback):
         self._bar.close()
 
     def _after_train(self):
-        self._bar.close()
+        if self._bar:   # training may get killed before the first step
+            self._bar.close()
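
Why the new guard is needed: _after_train runs during cleanup even when training is killed before the first step, i.e. before the bar has ever been created. A hypothetical stand-alone repro of the failure the `self._bar = None` plus `if self._bar:` pattern avoids (plain Python, not tensorpack code):

bar = None                     # self._bar as initialized in __init__
# ... training is killed here, before the first step creates the bar ...
# without the guard, cleanup would crash:
#   bar.close()  ->  AttributeError: 'NoneType' object has no attribute 'close'
if bar:                        # the fixed cleanup is a safe no-op instead
    bar.close()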