Commit 7d40e049 authored by Yuxin Wu

Write InfoGAN with tf.distributions and deprecate tfutils.distributions (fix #348)

parent 560bc84e
examples/GAN/InfoGAN-mnist.py

@@ -11,12 +11,11 @@ import sys
import argparse
from tensorpack import *
-from tensorpack.utils.viz import *
-from tensorpack.tfutils.distributions import *
-from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
+from tensorpack.utils import viz
+from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils import optimizer, summary
import tensorpack.tfutils.symbolic_functions as symbf
-from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
+from tensorpack.tfutils.gradproc import ScaleGradient
from tensorpack.dataflow import dataset
from GAN import GANTrainer, GANModelDesc
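The newly imported under_name_scope groups a function's ops under a TF name scope. A rough stand-in for what it does (my approximation; tensorpack's real decorator has more options):

import tensorflow as tf
from functools import wraps

def under_name_scope():
    # Approximation: run the decorated function's ops inside a
    # tf.name_scope named after the function itself.
    def _decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with tf.name_scope(func.__name__):
                return func(*args, **kwargs)
        return wrapper
    return _decorator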
@@ -31,17 +30,51 @@ A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU
"""
BATCH = 128
+# latent space is cat(10) x uni(2) x noise(NOISE_DIM)
+NUM_CLASS = 10
+NUM_UNIFORM = 2
+DIST_PARAM_DIM = NUM_CLASS + NUM_UNIFORM
NOISE_DIM = 62
+# prior: the assumption of how the latent factors are presented in the dataset
+DIST_PRIOR_PARAM = [1.] * NUM_CLASS + [0.] * NUM_UNIFORM
-class GaussianWithUniformSample(GaussianDistribution):
+def get_distributions(vec_cat, vec_uniform):
+    cat = tf.distributions.Categorical(logits=vec_cat, validate_args=True, name='cat')
+    uni = tf.distributions.Normal(vec_uniform, scale=1., validate_args=True, allow_nan_stats=False, name='uni_a')
+    return cat, uni
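For orientation, a minimal sketch (mine, not part of the commit) of how these two distribution objects score samples; the Categorical takes integer class indices, while the Normal scores each continuous code per dimension:

import tensorflow as tf

logits = tf.zeros([4, 10])     # a batch of 4 uniform categorical parameters
loc = tf.zeros([4, 2])         # predicted means for the 2 continuous codes
cat = tf.distributions.Categorical(logits=logits)
uni = tf.distributions.Normal(loc, scale=1.)
lp_cat = cat.log_prob(tf.constant([0, 1, 2, 3]))  # shape [4]: one log-prob per sample
lp_uni = uni.log_prob(tf.zeros([4, 2]))           # shape [4, 2]: per-dimension log-probs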
+def entropy_from_samples(samples, vec):
+    """
+    Estimate H(x|s) ~= -E_{x \sim P(x|s)}[\log Q(x|s)], where x are samples, and Q is parameterized by vec.
+    """
+    samples_cat = tf.argmax(samples[:, :NUM_CLASS], axis=1, output_type=tf.int32)
+    samples_uniform = samples[:, NUM_CLASS:]
+    cat, uniform = get_distributions(vec[:, :NUM_CLASS], vec[:, NUM_CLASS:])
+
+    def neg_logprob(dist, sample, name):
+        nll = -dist.log_prob(sample)
+        # average over batch
+        return tf.reduce_sum(tf.reduce_mean(nll, axis=0), name=name)
+
+    entropies = [neg_logprob(cat, samples_cat, 'nll_cat'),
+                 neg_logprob(uniform, samples_uniform, 'nll_uniform')]
+    return entropies
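A hypothetical wiring example (the placeholder names are mine): the two returned scalars are the categorical and continuous NLL terms that the model later sums into its conditional-entropy estimate.

zc = tf.placeholder(tf.float32, [None, DIST_PARAM_DIM], name='zc')
dist_param = tf.placeholder(tf.float32, [None, DIST_PARAM_DIM], name='dist_param')
nll_cat, nll_uniform = entropy_from_samples(zc, dist_param)
cond_entropy = tf.add_n([nll_cat, nll_uniform])  # scalar Monte-Carlo estimate of H(x|s)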
+@under_name_scope()
+def sample_prior(batch_size):
+    cat, _ = get_distributions(DIST_PRIOR_PARAM[:NUM_CLASS], DIST_PRIOR_PARAM[NUM_CLASS:])
+    sample_cat = tf.one_hot(cat.sample(batch_size), NUM_CLASS)
    """
    OpenAI official code actually models the "uniform" latent code as
    a Gaussian distribution, but obtains the samples from a uniform distribution.
    We follow the official code for now.
    """
-    def _sample(self, batch_size, theta):
-        return tf.random_uniform([batch_size, self.dim], -1, 1)
+    sample_uni = tf.random_uniform([batch_size, NUM_UNIFORM], -1, 1)
+    samples = tf.concat([sample_cat, sample_uni], axis=1)
+    return samples
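With the constants above, each prior sample is a 12-dimensional code. A quick shape check (a sketch, not in the commit):

z_code = sample_prior(BATCH)  # 10-way one-hot class plus 2 uniform codes
print(z_code.shape)           # => (128, 12), i.e. [BATCH, DIST_PARAM_DIM]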
class Model(GANModelDesc):
@@ -73,24 +106,15 @@ class Model(GANModelDesc):
encoder = (LinearWrap(l)
.FullyConnected('fce1', 128, nl=tf.identity)
.BatchNorm('bne').LeakyReLU()
-                   .FullyConnected('fce-out', self.factors.param_dim, nl=tf.identity)())
+                   .FullyConnected('fce-out', DIST_PARAM_DIM, nl=tf.identity)())
return logits, encoder
def _build_graph(self, inputs):
real_sample = inputs[0]
real_sample = tf.expand_dims(real_sample, -1)
-        # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
-        self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
-                                                       GaussianWithUniformSample("uni_a", 1),
-                                                       GaussianWithUniformSample("uni_b", 1)])
-        # prior: the assumption how the factors are presented in the dataset
-        prior = tf.constant([0.1] * 10 + [0, 0], tf.float32, [12], name='prior')
-        batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior')
# sample the latent code:
-        zc = symbf.shapeless_placeholder(
-            self.factors.sample(BATCH, prior), 0, name='z_code')
+        zc = symbf.shapeless_placeholder(sample_prior(BATCH), 0, name='z_code')
z_noise = symbf.shapeless_placeholder(
tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
z = tf.concat([zc, z_noise], 1, name='z')
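For orientation (my annotation, not part of the diff), the generator input now decomposes as follows; shapeless_placeholder erases the batch dimension so the codes can also be fed directly at inference time:

# zc      : [BATCH, 12] -- 10-way one-hot class + 2 continuous codes (feedable)
# z_noise : [BATCH, 62] -- unstructured noise (feedable)
# z       : [BATCH, 74] -- their concatenation, the generator's input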
@@ -115,28 +139,29 @@
= H(x) + E[\log P(x|s)]
The distribution from which zc is sampled, in this case, is set to a fixed prior already.
So the first term is a constant.
For the second term, we can maximize its variational lower bound:
E_{x \sim P(x|s)}[\log Q(x|s)]
where Q(x|s) is a proposal distribution to approximate P(x|s).
Here, Q(x|s) is assumed to be a distribution which shares the form
-        of self.factors, and whose parameters are predicted by the discriminator network.
+        of P, and whose parameters are predicted by the discriminator network.
"""
with tf.name_scope("mutual_information"):
-            ents = self.factors.entropy(zc, batch_prior)
-            entropy = tf.add_n(ents, name='total_entropy')
-            # Note that dropping this term has no effect because the entropy
-            # of prior is a constant. The paper mentioned it but didn't use it.
-            # Adding this term may make the curve less stable because the
-            # entropy estimated from the samples is not the true value.
-            # post-process output vector from discriminator to obtain valid distribution parameters
-            encoder_activation = self.factors.encoder_activation(dist_param)
-            cond_ents = self.factors.entropy(zc, encoder_activation)
-            cond_entropy = tf.add_n(cond_ents, name="total_conditional_entropy")
+            batch_prior = tf.tile(tf.expand_dims(DIST_PRIOR_PARAM, 0), [BATCH, 1], name='batch_prior')
+            with tf.name_scope('prior_entropy'):
+                cat, uni = get_distributions(DIST_PRIOR_PARAM[:NUM_CLASS], DIST_PRIOR_PARAM[NUM_CLASS:])
+                ents = [cat.entropy(name='cat_entropy'), tf.reduce_sum(uni.entropy(), name='uni_entropy')]
+                entropy = tf.add_n(ents, name='total_entropy')
+            # Note that the entropy of prior is a constant. The paper mentioned it but didn't use it.
+            with tf.name_scope('conditional_entropy'):
+                cond_ents = entropy_from_samples(zc, dist_param)
+                cond_entropy = tf.add_n(cond_ents, name="total_entropy")
MI = tf.subtract(entropy, cond_entropy, name='mutual_information')
-            summary.add_moving_summary(entropy, cond_entropy, MI, *ents)
+            summary.add_moving_summary(entropy, cond_entropy, MI, *cond_ents)
# default GAN objective
self.build_losses(real_pred, fake_pred)
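Because the prior is fixed, total_entropy is an analytic constant; a quick numeric check (my sketch, using the uniform categorical and unit-Gaussian codes defined above):

import numpy as np

h_cat = np.log(NUM_CLASS)                             # uniform Categorical: ln 10 ~ 2.303 nats
h_uni = NUM_UNIFORM * 0.5 * np.log(2 * np.pi * np.e)  # unit Gaussian: ~1.419 nats per dimension
print(h_cat + h_uni)                                  # ~5.140, the constant total_entropy should report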
@@ -151,7 +176,7 @@
self.collect_variables()
def _get_optimizer(self):
-        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
+        lr = tf.get_variable('learning_rate', initializer=2e-4, dtype=tf.float32, trainable=False)
opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
# generator learns 5 times faster
return optimizer.apply_grad_processors(
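The removed symbf.get_scalar_var is replaced by the plain TensorFlow idiom: a non-trainable scalar variable that schedule callbacks can assign to. A standalone sketch (the summary line is my addition, approximating the old summary=True flag):

lr = tf.get_variable('learning_rate', initializer=2e-4, dtype=tf.float32, trainable=False)
tf.summary.scalar('learning_rate-summary', lr)  # assumed stand-in for the old summary=True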
@@ -165,7 +190,7 @@ def get_data():
def get_config():
-    logger.auto_set_dir()
+    logger.auto_set_dir('d')
return TrainConfig(
dataflow=get_data(),
callbacks=[ModelSaver(keep_freq=0.1)],
@@ -195,26 +220,26 @@ def sample(model_path):
z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
zc = np.concatenate((z_cat, z_uni * 0, z_uni * 0), axis=1)
o = pred(zc, z_noise)[0]
-    viz1 = stack_patches(o, nr_row=10, nr_col=10)
+    viz1 = viz.stack_patches(o, nr_row=10, nr_col=10)
viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))
    # show effect of the first continuous variable with fixed noise
zc = np.concatenate((z_cat, z_uni, z_uni * 0), axis=1)
o = pred(zc, z_noise * 0)[0]
-    viz2 = stack_patches(o, nr_row=10, nr_col=10)
+    viz2 = viz.stack_patches(o, nr_row=10, nr_col=10)
viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))
    # show effect of the second continuous variable with fixed noise
zc = np.concatenate((z_cat, z_uni * 0, z_uni), axis=1)
o = pred(zc, z_noise * 0)[0]
-    viz3 = stack_patches(o, nr_row=10, nr_col=10)
+    viz3 = viz.stack_patches(o, nr_row=10, nr_col=10)
viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))
-    viz = stack_patches(
+    canvas = viz.stack_patches(
[viz1, viz2, viz3],
nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0))
-    interactive_imshow(viz)
+    viz.interactive_imshow(canvas)
if __name__ == '__main__':
......
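A note on the renames in sample(): once viz refers to the module, reusing it for the stitched image would shadow the module before the final call; hence canvas. A minimal repro of that pitfall (hypothetical images):

import numpy as np
from tensorpack.utils import viz

img_a = np.zeros((28, 28, 3), dtype='uint8')
img_b = np.full((28, 28, 3), 255, dtype='uint8')
viz = viz.stack_patches([img_a, img_b], nr_row=1, nr_col=2)  # rebinds 'viz' to an ndarray
# viz.interactive_imshow(viz)  # would now fail: the module name has been shadowed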
tensorpack/tfutils/distributions.py

@@ -2,6 +2,7 @@ import tensorflow as tf
from functools import wraps
import numpy as np
+from ..utils.develop import log_deprecated
from .common import get_tf_version_number
__all__ = ['Distribution',
@@ -59,6 +60,7 @@ class Distribution(object):
distribution.
"""
self.name = name
log_deprecated("tfutils.distributions", "Please use tf.distributions instead!", "2017-12-10")
@class_scope
def loglikelihood(self, x, theta):
......
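For context, log_deprecated(name, text, eos) emits a deprecation warning pointing at the replacement and a removal date. A rough stand-in for its behavior (my approximation, not tensorpack's actual implementation):

import logging

def log_deprecated(name, text, eos=""):
    # Warn that 'name' is deprecated, with replacement advice and the planned removal date.
    logging.warning("%s is deprecated. %s It will be removed after %s.", name, text, eos)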