Commit 5d23f241 authored by Yuxin Wu's avatar Yuxin Wu

fix linting.

parent 19c6b446
...@@ -64,12 +64,12 @@ class Model(GANModelDesc): ...@@ -64,12 +64,12 @@ class Model(GANModelDesc):
GaussianDistribution("uni_a", 1), GaussianDistribution("uni_a", 1),
GaussianDistribution("uni_b", 1)]) GaussianDistribution("uni_b", 1)])
# prior: the assumption how the factors are presented in the dataset # prior: the assumption how the factors are presented in the dataset
prior = tf.constant([0.1] * 10 + [0,0], tf.float32, [12], name='prior') prior = tf.constant([0.1] * 10 + [0, 0], tf.float32, [12], name='prior')
batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior') batch_prior = tf.tile(tf.expand_dims(prior, 0), [BATCH, 1], name='batch_prior')
# sample the latent code zc: # sample the latent code zc:
sample = self.factors.dists[0].sample( sample = self.factors.dists[0].sample(
BATCH, tf.constant([0.1]*10, tf.float32, shape=[10])) BATCH, tf.constant([0.1] * 10, tf.float32, shape=[10]))
z_cat = symbf.remove_shape(sample, 0, name='z_cat') z_cat = symbf.remove_shape(sample, 0, name='z_cat')
# still sample the latent code from a uniform distribution. # still sample the latent code from a uniform distribution.
z_uni_a = symbf.remove_shape( z_uni_a = symbf.remove_shape(
......
...@@ -257,6 +257,10 @@ class ProductDistribution(Distribution): ...@@ -257,6 +257,10 @@ class ProductDistribution(Distribution):
def param_dim(self): def param_dim(self):
return np.sum([d.param_dim for d in self.dists]) return np.sum([d.param_dim for d in self.dists])
@property
def sample_dim(self):
    # Total dimensionality of one sample drawn from this product distribution:
    # the sum of sample_dim over all component distributions (mirrors the
    # param_dim property defined just above, which sums param_dim the same way).
    # NOTE(review): np.sum keeps the numpy scalar return type; presumably
    # self.dists is the list of component Distribution objects — set elsewhere
    # in the class, not visible in this view.
    return np.sum([d.sample_dim for d in self.dists])
def _splitter(self, s, param): def _splitter(self, s, param):
"""Input is split into a list of chunks according """Input is split into a list of chunks according
to dist.param_dim along axis=1 to dist.param_dim along axis=1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment