Commit 7d40e049 authored by Yuxin Wu

Write InfoGAN with tf.distributions and deprecate tfutils.distributions (fix #348)

parent 560bc84e
@@ -11,12 +11,11 @@ import sys
 import argparse
 from tensorpack import *
-from tensorpack.utils.viz import *
-from tensorpack.tfutils.distributions import *
-from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
+from tensorpack.utils import viz
+from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
 from tensorpack.tfutils import optimizer, summary
 import tensorpack.tfutils.symbolic_functions as symbf
-from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
+from tensorpack.tfutils.gradproc import ScaleGradient
 from tensorpack.dataflow import dataset
 from GAN import GANTrainer, GANModelDesc
@@ -31,17 +30,51 @@ A pretrained model is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU
 """
 BATCH = 128
+# latent space is cat(10) x uni(2) x noise(NOISE_DIM)
+NUM_CLASS = 10
+NUM_UNIFORM = 2
+DIST_PARAM_DIM = NUM_CLASS + NUM_UNIFORM
 NOISE_DIM = 62
+# prior: the assumption how the latent factors are presented in the dataset
+DIST_PRIOR_PARAM = [1.] * NUM_CLASS + [0.] * NUM_UNIFORM
 
 
-class GaussianWithUniformSample(GaussianDistribution):
+def get_distributions(vec_cat, vec_uniform):
+    cat = tf.distributions.Categorical(logits=vec_cat, validate_args=True, name='cat')
+    uni = tf.distributions.Normal(vec_uniform, scale=1., validate_args=True, allow_nan_stats=False, name='uni_a')
+    return cat, uni
+
+
+def entropy_from_samples(samples, vec):
+    """
+    Estimate H(x|s) ~= -E_{x \sim P(x|s)}[\log Q(x|s)], where x are samples, and Q is parameterized by vec.
+    """
+    samples_cat = tf.argmax(samples[:, :NUM_CLASS], axis=1, output_type=tf.int32)
+    samples_uniform = samples[:, NUM_CLASS:]
+    cat, uniform = get_distributions(vec[:, :NUM_CLASS], vec[:, NUM_CLASS:])
+
+    def neg_logprob(dist, sample, name):
+        nll = -dist.log_prob(sample)
+        # average over batch
+        return tf.reduce_sum(tf.reduce_mean(nll, axis=0), name=name)
+
+    entropies = [neg_logprob(cat, samples_cat, 'nll_cat'),
+                 neg_logprob(uniform, samples_uniform, 'nll_uniform')]
+    return entropies
+
+
+@under_name_scope()
+def sample_prior(batch_size):
+    cat, _ = get_distributions(DIST_PRIOR_PARAM[:NUM_CLASS], DIST_PRIOR_PARAM[NUM_CLASS:])
+    sample_cat = tf.one_hot(cat.sample(batch_size), NUM_CLASS)
     """
     OpenAI official code actually models the "uniform" latent code as
     a Gaussian distribution, but obtain the samples from a uniform distribution.
+    We follow the official code for now.
     """
-    def _sample(self, batch_size, theta):
-        return tf.random_uniform([batch_size, self.dim], -1, 1)
+    sample_uni = tf.random_uniform([batch_size, NUM_UNIFORM], -1, 1)
+    samples = tf.concat([sample_cat, sample_uni], axis=1)
+    return samples
 
 
 class Model(GANModelDesc):
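Editor's note: for readers new to the tf.distributions API this commit switches to, here is a minimal, standalone sketch (not part of the commit) of the calls that get_distributions(), entropy_from_samples() and sample_prior() rely on. It assumes TensorFlow 1.x, where tf.distributions.Categorical and tf.distributions.Normal expose sample(), log_prob() and entropy():

# Standalone sketch -- not part of the commit. Assumes TensorFlow 1.x.
import tensorflow as tf

logits = tf.constant([1.] * 10)                     # same prior as DIST_PRIOR_PARAM[:NUM_CLASS]
cat = tf.distributions.Categorical(logits=logits)
uni = tf.distributions.Normal(loc=tf.zeros(2), scale=1.)

sample = cat.sample(4)                              # four class indices drawn from the prior
cat_ent = cat.entropy()                             # analytic entropy, log(10) for a uniform categorical
uni_nll = -uni.log_prob(tf.constant([0.3, -0.8]))   # per-dimension negative log-likelihood
uni_ent = tf.reduce_sum(uni.entropy())              # summed per-dimension Gaussian entropy

with tf.Session() as sess:
    print(sess.run([sample, cat_ent, uni_nll, uni_ent]))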
@@ -73,24 +106,15 @@ class Model(GANModelDesc):
         encoder = (LinearWrap(l)
                    .FullyConnected('fce1', 128, nl=tf.identity)
                    .BatchNorm('bne').LeakyReLU()
-                   .FullyConnected('fce-out', self.factors.param_dim, nl=tf.identity)())
+                   .FullyConnected('fce-out', DIST_PARAM_DIM, nl=tf.identity)())
         return logits, encoder
 
     def _build_graph(self, inputs):
         real_sample = inputs[0]
         real_sample = tf.expand_dims(real_sample, -1)
 
-        # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
-        self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
-                                                       GaussianWithUniformSample("uni_a", 1),
-                                                       GaussianWithUniformSample("uni_b", 1)])
-        # prior: the assumption how the factors are presented in the dataset
-        prior = tf.constant([0.1] * 10 + [0, 0], tf.float32, [12], name='prior')
-        batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior')
-
         # sample the latent code:
-        zc = symbf.shapeless_placeholder(
-            self.factors.sample(BATCH, prior), 0, name='z_code')
+        zc = symbf.shapeless_placeholder(sample_prior(BATCH), 0, name='z_code')
         z_noise = symbf.shapeless_placeholder(
             tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
         z = tf.concat([zc, z_noise], 1, name='z')
@@ -115,28 +139,29 @@ class Model(GANModelDesc):
             = H(x) + E[\log P(x|s)]
 
         The distribution from which zc is sampled, in this case, is set to a fixed prior already.
+        So the first term is a constant.
         For the second term, we can maximize its variational lower bound:
                 E_{x \sim P(x|s)}[\log Q(x|s)]
         where Q(x|s) is a proposal distribution to approximate P(x|s).
 
         Here, Q(x|s) is assumed to be a distribution which shares the form
-        of self.factors, and whose parameters are predicted by the discriminator network.
+        of P, and whose parameters are predicted by the discriminator network.
         """
         with tf.name_scope("mutual_information"):
-            ents = self.factors.entropy(zc, batch_prior)
+            batch_prior = tf.tile(tf.expand_dims(DIST_PRIOR_PARAM, 0), [BATCH, 1], name='batch_prior')
+            with tf.name_scope('prior_entropy'):
+                cat, uni = get_distributions(DIST_PRIOR_PARAM[:NUM_CLASS], DIST_PRIOR_PARAM[NUM_CLASS:])
+                ents = [cat.entropy(name='cat_entropy'), tf.reduce_sum(uni.entropy(), name='uni_entropy')]
             entropy = tf.add_n(ents, name='total_entropy')
-            # Note that dropping this term has no effect because the entropy
-            # of prior is a constant. The paper mentioned it but didn't use it.
-            # Adding this term may make the curve less stable because the
-            # entropy estimated from the samples is not the true value.
+            # Note that the entropy of prior is a constant. The paper mentioned it but didn't use it.
 
-            # post-process output vector from discriminator to obtain valid distribution parameters
-            encoder_activation = self.factors.encoder_activation(dist_param)
-            cond_ents = self.factors.entropy(zc, encoder_activation)
-            cond_entropy = tf.add_n(cond_ents, name="total_conditional_entropy")
+            with tf.name_scope('conditional_entropy'):
+                cond_ents = entropy_from_samples(zc, dist_param)
+                cond_entropy = tf.add_n(cond_ents, name="total_entropy")
 
             MI = tf.subtract(entropy, cond_entropy, name='mutual_information')
-            summary.add_moving_summary(entropy, cond_entropy, MI, *ents)
+            summary.add_moving_summary(entropy, cond_entropy, MI, *cond_ents)
 
         # default GAN objective
         self.build_losses(real_pred, fake_pred)
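Editor's note: spelled out, the bound that the mutual_information block maximizes is the standard InfoGAN variational lower bound (here x is the latent code zc and s the generated image, matching the docstring's notation):

I(x; s) = H(x) - H(x|s)
        = H(x) + \mathbb{E}_{s}\,\mathbb{E}_{x \sim P(x|s)}[\log P(x|s)]
        \ge H(x) + \mathbb{E}_{s}\,\mathbb{E}_{x \sim P(x|s)}[\log Q(x|s)],

where the gap is \mathbb{E}_{s}\left[ D_{\mathrm{KL}}\!\left( P(\cdot|s) \,\|\, Q(\cdot|s) \right) \right] \ge 0. In the code, entropy is the analytic H(x) of the fixed prior, cond_entropy estimates -E[\log Q(x|s)] from samples via entropy_from_samples, and MI = entropy - cond_entropy is the quantity maximized alongside the GAN objective.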
@@ -151,7 +176,7 @@ class Model(GANModelDesc):
         self.collect_variables()
 
     def _get_optimizer(self):
-        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
+        lr = tf.get_variable('learning_rate', initializer=2e-4, dtype=tf.float32, trainable=False)
         opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
         # generator learns 5 times faster
         return optimizer.apply_grad_processors(
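Editor's note: the learning-rate change above swaps the symbf.get_scalar_var helper for a plain non-trainable variable. A minimal sketch of that pattern (not the repository's code; the tf.assign schedule below is hypothetical):

import tensorflow as tf

lr = tf.get_variable('learning_rate', initializer=2e-4, dtype=tf.float32, trainable=False)
opt = tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6)
halve_lr = tf.assign(lr, lr * 0.5)   # hypothetical: how a schedule callback could shrink it

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(lr))    # 0.0002
    sess.run(halve_lr)
    print(sess.run(lr))    # 0.0001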
@@ -165,7 +190,7 @@ def get_data():
 
 def get_config():
-    logger.auto_set_dir()
+    logger.auto_set_dir('d')
     return TrainConfig(
         dataflow=get_data(),
         callbacks=[ModelSaver(keep_freq=0.1)],
@@ -195,26 +220,26 @@ def sample(model_path):
     z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
     zc = np.concatenate((z_cat, z_uni * 0, z_uni * 0), axis=1)
     o = pred(zc, z_noise)[0]
-    viz1 = stack_patches(o, nr_row=10, nr_col=10)
+    viz1 = viz.stack_patches(o, nr_row=10, nr_col=10)
     viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))
 
     # show effect of first continous variable with fixed noise
     zc = np.concatenate((z_cat, z_uni, z_uni * 0), axis=1)
     o = pred(zc, z_noise * 0)[0]
-    viz2 = stack_patches(o, nr_row=10, nr_col=10)
+    viz2 = viz.stack_patches(o, nr_row=10, nr_col=10)
     viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))
 
     # show effect of second continous variable with fixed noise
     zc = np.concatenate((z_cat, z_uni * 0, z_uni), axis=1)
     o = pred(zc, z_noise * 0)[0]
-    viz3 = stack_patches(o, nr_row=10, nr_col=10)
+    viz3 = viz.stack_patches(o, nr_row=10, nr_col=10)
     viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))
 
-    viz = stack_patches(
+    canvas = viz.stack_patches(
         [viz1, viz2, viz3],
         nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0))
-    interactive_imshow(viz)
+    viz.interactive_imshow(canvas)
 
 
 if __name__ == '__main__':
...
@@ -2,6 +2,7 @@ import tensorflow as tf
 from functools import wraps
 import numpy as np
 
+from ..utils.develop import log_deprecated
 from .common import get_tf_version_number
 
 __all__ = ['Distribution',
@@ -59,6 +60,7 @@ class Distribution(object):
             distribution.
         """
         self.name = name
+        log_deprecated("tfutils.distributions", "Please use tf.distributions instead!", "2017-12-10")
 
     @class_scope
     def loglikelihood(self, x, theta):
...
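Editor's note: taken together, the two files implement the migration stated in the commit message. The InfoGAN example no longer needs tensorpack's own distribution classes, and constructing any tfutils.distributions object now goes through log_deprecated. A rough sketch of the correspondence (for orientation only, not an exact API mapping):

# Sketch (not part of the commit): what downstream users of the old API will see.
from tensorpack.tfutils.distributions import CategoricalDistribution
cat = CategoricalDistribution('cat', 10)   # now logs a deprecation warning pointing to tf.distributions
#
# Rough correspondence used by the rewritten example above:
#   CategoricalDistribution("cat", 10)        ->  tf.distributions.Categorical(logits=vec_cat)
#   GaussianWithUniformSample("uni_a", 1)     ->  tf.distributions.Normal(vec_uniform, scale=1.)
#   ProductDistribution("factors", [...])     ->  concatenating / splitting the code tensor by hand
#   factors.sample(BATCH, prior)              ->  sample_prior(BATCH)
#   factors.entropy(zc, encoder_activation)   ->  entropy_from_samples(zc, dist_param)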