Commit bd3df33d authored by Yuxin Wu's avatar Yuxin Wu

add some began samples.

parent 085190b6
...@@ -17,6 +17,8 @@ from GAN import GANModelDesc, GANTrainer ...@@ -17,6 +17,8 @@ from GAN import GANModelDesc, GANTrainer
""" """
Boundary Equilibrium GAN. Boundary Equilibrium GAN.
See the docstring in DCGAN.py for usage. See the docstring in DCGAN.py for usage.
A pretrained model on CelebA is at https://drive.google.com/open?id=0B5uDfUQ1JTglUmgyZV8zQmNOTVU
""" """
...@@ -89,7 +91,6 @@ class Model(GANModelDesc): ...@@ -89,7 +91,6 @@ class Model(GANModelDesc):
def summary_image(name, x): def summary_image(name, x):
x = (x + 1.0) * 128.0 x = (x + 1.0) * 128.0
x = tf.clip_by_value(x, 0, 255) x = tf.clip_by_value(x, 0, 255)
x = tf.cast(x, tf.uint8)
tf.summary.image(name, x, max_outputs=30) tf.summary.image(name, x, max_outputs=30)
with argscope([Conv2D, FullyConnected], with argscope([Conv2D, FullyConnected],
...@@ -153,7 +154,7 @@ def get_config(): ...@@ -153,7 +154,7 @@ def get_config():
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args()
if args.sample: if args.sample:
DCGAN.sample(args.load) DCGAN.sample(args.load, 'gen/conv4.3/output')
else: else:
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import glob import glob
import numpy as np
import os, sys import os, sys
import argparse import argparse
...@@ -128,16 +129,17 @@ def get_config(): ...@@ -128,16 +129,17 @@ def get_config():
) )
def sample(model_path): def sample(model_path, output_name='gen/gen'):
pred = PredictConfig( pred = PredictConfig(
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
model=Model(), model=Model(),
input_names=['z'], input_names=['z'],
output_names=['gen/gen', 'z']) output_names=[output_name, 'z'])
pred = SimpleDatasetPredictor(pred, RandomZData((100, 100))) pred = SimpleDatasetPredictor(pred, RandomZData((100, opt.Z_DIM)))
for o in pred.get_result(): for o in pred.get_result():
o, zs = o[0] + 1, o[1] o, zs = o[0] + 1, o[1]
o = o * 128.0 o = o * 128.0
o = np.clip(o, 0, 255)
o = o[:, :, :, ::-1] o = o[:, :, :, ::-1]
viz = stack_patches(o, nr_row=10, nr_col=10, viz=True) viz = stack_patches(o, nr_row=10, nr_col=10, viz=True)
......
...@@ -26,11 +26,11 @@ Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/ ...@@ -26,11 +26,11 @@ Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/
+ Generated samples + Generated samples
![sample](demo/CelebA-samples.jpg) ![sample](demo/DCGAN-CelebA-samples.jpg)
+ Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man + Vector arithmetic: smiling woman - neutral woman + neutral man = smiling man
![vec](demo/CelebA-vec.jpg) ![vec](demo/DCGAN-CelebA-vec.jpg)
## Image2Image.py ## Image2Image.py
...@@ -61,6 +61,9 @@ Train a simple GAN on mnist, conditioned on the class labels. ...@@ -61,6 +61,9 @@ Train a simple GAN on mnist, conditioned on the class labels.
## WGAN.py, Improved-WGAN.py, BEGAN.py ## WGAN.py, Improved-WGAN.py, BEGAN.py
These variants are implemented by some small modifications on top of DCGAN.py. These variants are implemented by some small modifications on top of DCGAN.py.
Some BEGAN samples:
![began-sample](demo/BEGAN-CelebA-samples.jpg)
## DiscoGAN-CelebA.py ## DiscoGAN-CelebA.py
......
...@@ -209,7 +209,7 @@ class TestPool(TestModel): ...@@ -209,7 +209,7 @@ class TestPool(TestModel):
self.assertTrue((res == 0).all()) self.assertTrue((res == 0).all())
def test_BilinearUpSample(self): def test_BilinearUpSample(self):
h, w = 5, 5 h, w = 12, 12
scale = 2 scale = 2
mat = np.random.rand(h, w).astype('float32') mat = np.random.rand(h, w).astype('float32')
...@@ -224,11 +224,6 @@ class TestPool(TestModel): ...@@ -224,11 +224,6 @@ class TestPool(TestModel):
diff = np.abs(res2 - res) diff = np.abs(res2 - res)
# TODO not equivalent to rescale on edge?
#diff[0, :] = 0
#diff[-1, :] = 0
#diff[:, 0] = 0
#diff[:, -1] = 0
# if not diff.max() < 1e-4: # if not diff.max() < 1e-4:
# import IPython # import IPython
# IPython.embed(config=IPython.terminal.ipapp.load_default_config()) # IPython.embed(config=IPython.terminal.ipapp.load_default_config())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment