Commit e1278514 authored by Yuxin Wu's avatar Yuxin Wu

fix linting and small change in examples

parent 1a931e90
...@@ -60,8 +60,7 @@ class Model(GANModelDesc): ...@@ -60,8 +60,7 @@ class Model(GANModelDesc):
# latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM) # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10), self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
GaussianDistribution("uni_a", 1), GaussianDistribution("uni_a", 1),
GaussianDistribution("uni_b", 1), GaussianDistribution("uni_b", 1)])
])
# sample the latent code zc: # sample the latent code zc:
idxs = tf.squeeze(tf.multinomial(tf.zeros([BATCH, 10]), 1), 1) idxs = tf.squeeze(tf.multinomial(tf.zeros([BATCH, 10]), 1), 1)
......
...@@ -176,7 +176,7 @@ def get_data(train_or_test): ...@@ -176,7 +176,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count())) ds = PrefetchDataZMQ(ds, min(30, multiprocessing.cpu_count()))
return ds return ds
......
...@@ -17,7 +17,7 @@ A small convnet model for Cifar10 or Cifar100 dataset. ...@@ -17,7 +17,7 @@ A small convnet model for Cifar10 or Cifar100 dataset.
Cifar10: Cifar10:
91% accuracy after 50k step. 91% accuracy after 50k step.
19.3 step/s on Tesla M40 30 step/s on TitanX
Not a good model for Cifar100, just for demonstration. Not a good model for Cifar100, just for demonstration.
""" """
...@@ -98,7 +98,7 @@ def get_data(train_or_test, cifar_classnum): ...@@ -98,7 +98,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchData(ds, 3, 2) ds = PrefetchDataZMQ(ds, 3)
return ds return ds
......
...@@ -362,6 +362,7 @@ def soft_triplet_loss(anchor, positive, negative, extra=True): ...@@ -362,6 +362,7 @@ def soft_triplet_loss(anchor, positive, negative, extra=True):
else: else:
return loss return loss
def remove_shape(x, axis, name): def remove_shape(x, axis, name):
""" """
Make the static shape of a tensor less specific, by Make the static shape of a tensor less specific, by
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment