Commit 3e0a861e authored by Yuxin Wu's avatar Yuxin Wu

add a simple conditional gan

parent f41e6326
...@@ -19,6 +19,7 @@ TensorFlow itself also changes API and those are not listed here.
  See [commit](https://github.com/ppwwyyxx/tensorpack/commit/651a5aea8f9aacad7147542021dcf106fc824bc2) to change your code.
* 2016/12/15. The `predict_tower` option is in `TrainConfig` now instead of `Trainer`. See
  [commit](https://github.com/ppwwyyxx/tensorpack/commit/99c70935a7f72050f45891fbbcc49c4ce43aedce).
* 2016/11/10. The `{input,output}_var_names` argument in `PredictConfig` is renamed to `{input,output}_names` (a short sketch follows this list). See [commit](https://github.com/ppwwyyxx/tensorpack/commit/77bcc8b1afc984a569f6ec3eda0a3c47b4e2923a).
* 2016/11/06. The inferencer `ClassificationError` now expects the vector tensor returned by
  `prediction_incorrect` instead of the "wrong" tensor. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/740e9d8ca146af5a911f68a369dd7348243a2253)
  to make changes.
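For the 2016/11/10 rename, here is a minimal sketch of the new `PredictConfig` call. The argument values are taken from the ConditionalGAN example added in this commit; `Model` stands for whatever `ModelDesc` subclass your own script defines.

```python
from tensorpack import PredictConfig

# `input_var_names` / `output_var_names` are now `input_names` / `output_names`.
pred = PredictConfig(
    model=Model(),                 # your ModelDesc subclass, e.g. the one in ConditionalGAN-mnist.py
    input_names=['label', 'z'],    # was: input_var_names
    output_names=['gen/gen'])      # was: output_var_names
```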
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ConditionalGAN-mnist.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf
import os
import sys
import cv2
import argparse
from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
To train:
./ConditionalGAN-mnist.py
To visualize:
./ConditionalGAN-mnist.py --sample --load path/to/model
"""
BATCH = 128
class Model(GANModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, 28, 28), 'input'),
                InputVar(tf.int32, (None,), 'label')]

    def generator(self, z, y):
        # condition the generator on the one-hot label y by concatenating it
        # with the noise / feature maps at every stage
        l = FullyConnected('fc0', tf.concat([z, y], 1), 1024, nl=BNReLU)
        l = FullyConnected('fc1', tf.concat([l, y], 1), 64 * 2 * 7 * 7, nl=BNReLU)
        l = tf.reshape(l, [-1, 7, 7, 64 * 2])
        y = tf.reshape(y, [-1, 1, 1, 10])
        l = tf.concat([l, tf.tile(y, [1, 7, 7, 1])], 3)
        l = Deconv2D('deconv1', l, [14, 14, 64 * 2], 5, 2, nl=BNReLU)
        l = tf.concat([l, tf.tile(y, [1, 14, 14, 1])], 3)
        l = Deconv2D('deconv2', l, [28, 28, 1], 5, 2, nl=tf.identity)
        l = tf.nn.tanh(l, name='gen')
        return l
    def discriminator(self, imgs, y):
        """Return a (b, 1) logits tensor; the label y is concatenated at every stage."""
        yv = y
        y = tf.reshape(y, [-1, 1, 1, 10])
        with argscope(Conv2D, nl=tf.identity, kernel_shape=5, stride=2), \
                argscope(LeakyReLU, alpha=0.2):
            l = (LinearWrap(imgs)
                 .ConcatWith(tf.tile(y, [1, 28, 28, 1]), 3)
                 .Conv2D('conv0', 11)
                 .LeakyReLU()
                 .ConcatWith(tf.tile(y, [1, 14, 14, 1]), 3)
                 .Conv2D('conv1', 74)
                 .BatchNorm('bn1').LeakyReLU()
                 .apply(symbf.batch_flatten)
                 .ConcatWith(yv, 1)
                 .FullyConnected('fc1', 1024, nl=tf.identity)
                 .BatchNorm('bn2').LeakyReLU()
                 .ConcatWith(yv, 1)
                 .FullyConnected('fct', 1, nl=tf.identity)())
        return l
    def _build_graph(self, input_vars):
        image_pos, y = input_vars
        # rescale images to [-1, 1] to match the tanh output of the generator
        image_pos = tf.expand_dims(image_pos * 2.0 - 1, -1)
        y = tf.one_hot(y, 10, name='label_onehot')

        z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train')
        # expose z as a placeholder (flexible batch size) that can be fed at
        # inference time; during training it defaults to the random tensor above
        z = symbf.shapeless_placeholder(z, [0], name='z')

        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
            with tf.variable_scope('gen'):
                image_gen = self.generator(z, y)
            tf.summary.image('gen', image_gen, 30)
            with tf.variable_scope('discrim'):
                vecpos = self.discriminator(image_pos, y)
            with tf.variable_scope('discrim', reuse=True):
                vecneg = self.discriminator(image_gen, y)

        self.build_losses(vecpos, vecneg)
        self.collect_variables()
def get_data():
    ds = ConcatData([dataset.Mnist('train'), dataset.Mnist('test')])
    ds = BatchData(ds, BATCH)
    return ds
def get_config():
    logger.auto_set_dir()
    dataset = get_data()
    lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
    return TrainConfig(
        dataflow=dataset,
        optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
        callbacks=[ModelSaver()],
        session_config=get_default_sess_config(0.5),
        model=Model(),
        steps_per_epoch=500,
        max_epoch=100,
    )
def sample(model_path):
    pred = PredictConfig(
        session_init=get_model_loader(model_path),
        model=Model(),
        input_names=['label', 'z'],
        output_names=['gen/gen'])

    # generate 100 samples whose labels cycle through the 10 digit classes
    ds = MapData(RandomZData((100, 100)),
                 lambda dp: [np.arange(100) % 10, dp[0]])
    pred = SimpleDatasetPredictor(pred, ds)
    for o in pred.get_result():
        o = o[0] * 255.0
        viz = next(build_patch_list(o, nr_row=10, nr_col=10))
        viz = cv2.resize(viz, (800, 800))
        interactive_imshow(viz)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--sample', action='store_true')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.sample:
        sample(args.load)
    else:
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        GANTrainer(config).train()
...@@ -27,7 +27,7 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model:
./DCGAN-CelebA.py --load path/to/model --sample
"""
CFG.SHAPE = 64
...
...@@ -17,6 +17,14 @@ import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
from GAN import GANTrainer, GANModelDesc
"""
To train:
./InfoGAN-mnist.py
To visualize:
./InfoGAN-mnist.py --sample --load path/to/model
"""
BATCH = 128
NOISE_DIM = 62
...
# Generative Adversarial Networks
Reproduce the following GAN-related methods:
+ DCGAN ([Unsupervised Representation Learning with DCGAN](https://arxiv.org/abs/1511.06434))
+ pix2pix ([Image-to-image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf))
+ InfoGAN ([InfoGAN: Interpretable Representation Learning by Information Maximizing GAN](https://arxiv.org/abs/1606.03657))
+ Conditional GAN
Please see the __docstring__ in each script for detailed usage.
...@@ -46,3 +48,6 @@ It then maximizes mutual information between these latent variables and the imag
* Middle: 1 continuous latent variable controlled the rotation.
* Right: another continuous latent variable controlled the thickness.
## ConditionalGAN-mnist.py
Train a simple GAN on mnist, conditioned on the class labels.
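For intuition, here is an illustrative NumPy-only sketch of the conditioning trick (not part of the script): the class label is one-hot encoded and paired with the noise vector, mirroring the `tf.concat([z, y], 1)` performed inside the generator of `ConditionalGAN-mnist.py`.

```python
import numpy as np

def conditioned_noise(labels, noise_dim=100, num_classes=10):
    """Illustrative helper: pair a uniform noise vector with a one-hot class label."""
    z = np.random.uniform(-1, 1, size=(len(labels), noise_dim))
    onehot = np.eye(num_classes)[np.asarray(labels)]
    return np.concatenate([z, onehot], axis=1)

# 100 samples, 10 per digit class -- the same label layout used by sample()
labels = np.arange(100) % 10
print(conditioned_noise(labels).shape)  # (100, 110)
```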