Commit 19c6b446 authored by Yuxin Wu

implement sample() for distributions; add entropy term in InfoGAN.

parent e66857ba
@@ -58,22 +58,30 @@ class Model(GANModelDesc):
         real_sample = tf.expand_dims(real_sample * 2.0 - 1, -1)
         # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
+        # The OpenAI code actually uses a Gaussian distribution for the uniform
+        # codes, except in the sampling step. We follow the official implementation for now.
         self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
                                                        GaussianDistribution("uni_a", 1),
                                                        GaussianDistribution("uni_b", 1)])
+        # prior: the assumed distribution of the factors in the dataset
+        prior = tf.constant([0.1] * 10 + [0, 0], tf.float32, [12], name='prior')
+        batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior')
         # sample the latent code zc:
-        idxs = tf.squeeze(tf.multinomial(tf.zeros([BATCH, 10]), 1), 1)
-        sample = tf.one_hot(idxs, 10)
+        sample = self.factors.dists[0].sample(
+            BATCH, tf.constant([0.1] * 10, tf.float32, shape=[10]))
         z_cat = symbf.remove_shape(sample, 0, name='z_cat')
+        # still sample the uniform latent codes from a uniform distribution.
         z_uni_a = symbf.remove_shape(
             tf.random_uniform([BATCH, 1], -1, 1), 0, name='z_uni_a')
         z_uni_b = symbf.remove_shape(
             tf.random_uniform([BATCH, 1], -1, 1), 0, name='z_uni_b')
+        zc = tf.concat_v2([z_cat, z_uni_a, z_uni_b], 1, name='z_code')
+        # TODO ideally this could be done by self.factors.sample, if the sample
+        # method were consistent with the distribution
         z_noise = symbf.remove_shape(
             tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
-        zc = tf.concat_v2([z_cat, z_uni_a, z_uni_b], 1, name='z_code')
         z = tf.concat_v2([zc, z_noise], 1, name='z')

         with argscope([Conv2D, Deconv2D, FullyConnected],
@@ -110,16 +118,25 @@ class Model(GANModelDesc):
        of self.factors, and whose parameters are predicted by the discriminator network.
        """
        with tf.name_scope("mutual_information"):
-            ents = self.factors.entropy(zc, encoder_activation)
-            cond_entropy = tf.add_n(ents, name="total_conditional_entropy")
-            summary.add_moving_summary(cond_entropy, *ents)
+            ents = self.factors.entropy(zc, batch_prior)
+            entropy = tf.add_n(ents, name='total_entropy')
+            # Note that dropping this term has no effect because the entropy
+            # of the prior is a constant. The paper mentions it but doesn't use it.
+            # Adding this term may make the curve less stable because the
+            # entropy estimated from the samples is not the true value.
+            cond_ents = self.factors.entropy(zc, encoder_activation)
+            cond_entropy = tf.add_n(cond_ents, name="total_conditional_entropy")
+            MI = tf.subtract(entropy, cond_entropy, name='mutual_information')
+            summary.add_moving_summary(entropy, cond_entropy, MI, *ents)

        # default GAN objective
        self.build_losses(real_pred, fake_pred)

-        self.g_loss = tf.add(self.g_loss, cond_entropy, name='total_g_loss')
-        self.d_loss = tf.add(self.d_loss, cond_entropy, name='total_d_loss')
+        # subtract mutual information for latent factors (we want to maximize it)
+        self.g_loss = tf.subtract(self.g_loss, MI, name='total_g_loss')
+        self.d_loss = tf.subtract(self.d_loss, MI, name='total_d_loss')
        summary.add_moving_summary(self.g_loss, self.d_loss)
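For context on the loss change above: InfoGAN maximizes a variational lower bound on the mutual information between the latent code c and the generated sample x = G(z, c), which decomposes as

    I(c; x) = H(c) - H(c | x)

Since the prior over c is fixed, H(c) is a constant, which is what the comment about dropping the entropy term refers to: subtracting MI from g_loss and d_loss is, up to that constant, the same as adding the conditional entropy term that the code used before this commit. (This is the standard identity, shown here for reference, not code from the repository.)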
@@ -127,6 +144,7 @@ class Model(GANModelDesc):
        self.collect_variables()

    def get_gradient_processor_g(self):
+        # generator learns 5 times faster
        return [CheckGradient(), ScaleGradient(('.*', 5), log=False)]
@@ -88,6 +88,32 @@ class Distribution(object):
        """
        return tf.reduce_mean(-self.loglikelihood(x, theta), name="entropy")
+
+    @class_scope
+    def sample(self, batch_size, theta):
+        """
+        Sample a batch of vectors from this distribution, parameterized by theta.
+
+        Args:
+            batch_size(int): the batch size.
+            theta: a tensor of shape (param_dim,) or (batch, param_dim).
+
+        Returns:
+            a batch of samples of shape (batch, sample_dim)
+        """
+        assert isinstance(batch_size, int), batch_size
+        shp = theta.get_shape()
+        assert shp.ndims in [1, 2] and shp[-1] == self.param_dim, shp
+        if shp.ndims == 1:
+            theta = tf.tile(tf.expand_dims(theta, 0), [batch_size, 1],
+                            name='tiled_theta')
+        else:
+            assert shp[0] == batch_size, shp
+        x = self._sample(batch_size, theta)
+        assert x.get_shape().ndims == 2 and \
+            x.get_shape()[1] == self.sample_dim, \
+            x.get_shape()
+        return x
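For reference, a minimal usage sketch of the new sample() API, mirroring the call in the InfoGAN example above (BATCH and the flat prior are borrowed from that example; this is an illustration, not part of the commit):

    cat = CategoricalDistribution("cat", 10)
    # theta may be a single (param_dim,) vector; sample() tiles it to (batch, param_dim)
    z_cat = cat.sample(BATCH, tf.constant([0.1] * 10, tf.float32, shape=[10]))
    # z_cat: (BATCH, 10) one-hot samples drawn from the flat prior

Passing a (batch, param_dim) tensor instead skips the tiling, as the shape handling above shows.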
    @class_scope
    def encoder_activation(self, dist_param):
        """ An activation function to produce

@@ -107,7 +133,7 @@ class Distribution(object):
        Returns:
            int: the dimension of parameters of this distribution.
        """
-        raise NotImplementedError
+        raise NotImplementedError()

    @property
    def sample_dim(self):
@@ -115,14 +141,17 @@ class Distribution(object):
        Returns:
            int: the dimension of samples out of this distribution.
        """
-        raise NotImplementedError
+        raise NotImplementedError()

    def _loglikelihood(self, x, theta):
-        raise NotImplementedError
+        raise NotImplementedError()

    def _encoder_activation(self, dist_param):
        return dist_param

+    def _sample(self, batch_size, theta):
+        raise NotImplementedError()


class CategoricalDistribution(Distribution):
    """ Categorical distribution of a set of classes.
@@ -143,6 +172,11 @@ class CategoricalDistribution(Distribution):
    def _encoder_activation(self, dist_param):
        return tf.nn.softmax(dist_param)

+    def _sample(self, batch_size, theta):
+        ids = tf.squeeze(tf.multinomial(
+            tf.log(theta + 1e-8), num_samples=1), 1)
+        return tf.one_hot(ids, self.cardinality, name='sample')

    @property
    def param_dim(self):
        return self.cardinality
@@ -188,6 +222,15 @@ class GaussianDistribution(Distribution):
            stddev = tf.sqrt(tf.exp(stddev))
            return tf.concat_v2([mean, stddev], axis=1)

+    def _sample(self, batch_size, theta):
+        if self.fixed_std:
+            mean = theta
+            stddev = 1
+        else:
+            mean, stddev = tf.split(theta, 2, axis=1)
+        e = tf.random_normal(tf.shape(mean))
+        return tf.add(mean, e * stddev, name='sample')
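The Gaussian sampler above is the usual reparameterization trick: with e drawn from a standard normal,

    e ~ N(0, I)   =>   mean + stddev * e ~ N(mean, stddev^2)

so the returned tensor follows the Gaussian whose parameters are given by theta (with stddev fixed to 1 when fixed_std is set, consistent with the param_dim property of this class).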
    @property
    def param_dim(self):
        if self.fixed_std:
@@ -257,3 +300,9 @@ class ProductDistribution(Distribution):
            if dist.param_dim > 0:
                rsl.append(dist._encoder_activation(dist_param))
        return tf.concat_v2(rsl, 1)

+    def _sample(self, batch_size, theta):
+        ret = []
+        for dist, ti in zip(self.dists, self._splitter(theta, True)):
+            ret.append(dist._sample(batch_size, ti))
+        return tf.concat_v2(ret, 1, name='sample')
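ProductDistribution._sample splits theta into per-factor parameter blocks and concatenates the per-factor samples, so in principle the whole InfoGAN latent code could be drawn in one call, roughly:

    # hypothetical, not in the commit: one-call sampling of the full latent code
    zc = self.factors.sample(BATCH, batch_prior)   # (BATCH, 10 + 1 + 1)

As the TODO in the example notes, this isn't done yet because the Gaussian factors would then be sampled from N(0, 1) rather than the uniform(-1, 1) actually used for z_uni_a and z_uni_b.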
@@ -363,6 +363,7 @@ def soft_triplet_loss(anchor, positive, negative, extra=True):
    return loss

+# TODO not a good name.
def remove_shape(x, axis, name):
    """
    Make the static shape of a tensor less specific, by