Commit 15befae4 authored by Yuxin Wu

implement the new distribution used by infogan.

parent 5d23f241
@@ -21,6 +21,16 @@ BATCH = 128
 NOISE_DIM = 62

+class GaussianWithUniformSample(GaussianDistribution):
+    """
+    The OpenAI official code actually models the "uniform" latent code as
+    a Gaussian distribution, but obtains the samples from a uniform
+    distribution. We follow the official code for now.
+    """
+    def _sample(self, batch_size, theta):
+        return tf.random_uniform([batch_size, self.dim], -1, 1)
+
 class Model(GANModelDesc):
     def _get_input_vars(self):
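
The trick this hunk introduces is easier to see outside of tensorpack: the density used for the mutual-information term stays Gaussian, and only the sampling step changes. Below is a minimal NumPy sketch of that idea; the class and method names are illustrative, not tensorpack's API.

    import numpy as np

    class ToyGaussianWithUniformSample:
        """Toy stand-in: Gaussian log-likelihood, uniform sampler."""
        def __init__(self, dim):
            self.dim = dim

        def sample(self, batch_size):
            # samples are drawn from U(-1, 1), following the OpenAI code
            return np.random.uniform(-1, 1, size=(batch_size, self.dim))

        def loglikelihood(self, x, mean):
            # ...but the code is scored as a unit-variance Gaussian around `mean`
            return -0.5 * np.sum((x - mean) ** 2 + np.log(2 * np.pi), axis=1)
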
@@ -58,27 +68,15 @@ class Model(GANModelDesc):
         real_sample = tf.expand_dims(real_sample * 2.0 - 1, -1)

         # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
-        # OpenAI code actually uses Gaussian distribution for uniform, except
-        # in the sample step. We follow the official implementation for now.
         self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
-                                                       GaussianDistribution("uni_a", 1),
-                                                       GaussianDistribution("uni_b", 1)])
+                                                       GaussianWithUniformSample("uni_a", 1),
+                                                       GaussianWithUniformSample("uni_b", 1)])
         # prior: the assumption of how the factors are presented in the dataset
         prior = tf.constant([0.1] * 10 + [0, 0], tf.float32, [12], name='prior')
         batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior')

         # sample the latent code zc:
-        sample = self.factors.dists[0].sample(
-            BATCH, tf.constant([0.1] * 10, tf.float32, shape=[10]))
-        z_cat = symbf.remove_shape(sample, 0, name='z_cat')
-        # still sample the latent code from a uniform distribution.
-        z_uni_a = symbf.remove_shape(
-            tf.random_uniform([BATCH, 1], -1, 1), 0, name='z_uni_a')
-        z_uni_b = symbf.remove_shape(
-            tf.random_uniform([BATCH, 1], -1, 1), 0, name='z_uni_b')
-        zc = tf.concat_v2([z_cat, z_uni_a, z_uni_b], 1, name='z_code')
-        # TODO ideally this can be done by self.factors.sample, if sample
-        # method is consistent with the distribution
+        zc = symbf.remove_shape(self.factors.sample(BATCH, prior), 0, name='z_code')
         z_noise = symbf.remove_shape(
             tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
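
The ten deleted lines collapse into a single call because `ProductDistribution.sample` can hand each factor its slice of the parameter vector: ten categorical probabilities plus one placeholder parameter for each uniform code, which is why `prior` has length 12. Below is a hedged sketch of that dispatch; the `param_dim` attribute and loop structure are assumptions for illustration, not tensorpack's actual implementation.

    import tensorflow as tf

    def product_sample(dists, batch_size, theta):
        # give each factor distribution its slice of theta, then
        # concatenate the per-factor samples along the feature axis
        samples, offset = [], 0
        for d in dists:
            samples.append(d.sample(batch_size, theta[offset:offset + d.param_dim]))
            offset += d.param_dim
        return tf.concat(samples, 1)  # [batch_size, 10 + 1 + 1] here
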
@@ -175,7 +173,7 @@ def sample(model_path):
     pred = OfflinePredictor(PredictConfig(
         session_init=get_model_loader(model_path),
         model=Model(),
-        input_names=['z_cat', 'z_uni_a', 'z_uni_b', 'z_noise'],
+        input_names=['z_code', 'z_noise'],
         output_names=['gen/viz']))
     # sample all one-hot encodings (10 times)
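
Feeding only 'z_code' and 'z_noise' works even though both tensors were built with a fixed BATCH of 128: `symbf.remove_shape(x, 0, name=...)` erases the static size of axis 0, so the named tensor can later be fed 100-row arrays by the predictor. The following is a guess at its semantics, not tensorpack's actual implementation.

    import tensorflow as tf

    def remove_shape_sketch(x, axis, name):
        # forget the static size on `axis`; when nothing is fed through this
        # placeholder, the original tensor `x` is used as the default value
        shape = x.get_shape().as_list()
        shape[axis] = None
        return tf.placeholder_with_default(x, shape, name=name)
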
@@ -189,17 +187,20 @@ def sample(model_path):
     while True:
         # only categorical turned on
         z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
-        o = pred([z_cat, z_uni * 0, z_uni * 0, z_noise])[0]
+        zc = np.concatenate((z_cat, z_uni * 0, z_uni * 0), axis=1)
+        o = pred(zc, z_noise)[0]
         viz1 = next(build_patch_list(o, nr_row=10, nr_col=10))
         viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))

         # show effect of first continuous variable with fixed noise
-        o = pred([z_cat, z_uni, z_uni * 0, z_noise * 0])[0]
+        zc = np.concatenate((z_cat, z_uni, z_uni * 0), axis=1)
+        o = pred(zc, z_noise * 0)[0]
         viz2 = next(build_patch_list(o, nr_row=10, nr_col=10))
         viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))

         # show effect of second continuous variable with fixed noise
-        o = pred([z_cat, z_uni * 0, z_uni, z_noise * 0])[0]
+        zc = np.concatenate((z_cat, z_uni * 0, z_uni), axis=1)
+        o = pred(zc, z_noise * 0)[0]
         viz3 = next(build_patch_list(o, nr_row=10, nr_col=10))
         viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))
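
The sampling loop above uses `z_cat` and `z_uni`, which are defined in lines elided from this diff. A plausible construction, consistent with the comments (every one-hot encoding repeated 10 times, one continuous value per cell of a 10x10 grid), is shown below; this is a hypothetical reconstruction, not the file's actual code.

    import numpy as np

    # 100 x 10: each of the 10 one-hot category codes, repeated 10 times
    z_cat = np.repeat(np.eye(10), 10, axis=0)
    # 100 x 1: a sweep over the range of the continuous code
    z_uni = np.linspace(-1.0, 1.0, 100).reshape(100, 1)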