Commit 8f056dc1 authored by Yuxin Wu's avatar Yuxin Wu

use global namespace between WGAN and DCGAN so that arguments are easier to share

parent 5beab907
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: DCGAN-CelebA.py # File: DCGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import glob import glob
...@@ -11,6 +11,8 @@ from tensorpack import * ...@@ -11,6 +11,8 @@ from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.globvars import globalns as opt
from tensorpack.utils.globvars import use_global_argument
import tensorflow as tf import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
...@@ -18,25 +20,29 @@ from GAN import GANTrainer, RandomZData, GANModelDesc ...@@ -18,25 +20,29 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
""" """
1. Download the 'aligned&cropped' version of CelebA dataset 1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
2. Start training: 2. Start training:
./DCGAN-CelebA.py --data /path/to/img_align_celeba/ ./DCGAN-CelebA.py --data /path/to/img_align_celeba/ --crop-size 140
3. Visualize samples of a trained model: Generated samples will be available through tensorboard
3. Visualize samples with an existing model:
./DCGAN-CelebA.py --load path/to/model --sample ./DCGAN-CelebA.py --load path/to/model --sample
You can also train on other images (just use any directory of jpg files in You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`. `--data`). But you may need to change the preprocessing.
A pretrained model on CelebA is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE A pretrained model on CelebA is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
""" """
SHAPE = 64 # global vars
BATCH = 128 opt.SHAPE = 64
Z_DIM = 100 opt.BATCH = 128
opt.Z_DIM = 100
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'input')] return [InputDesc(tf.float32, (None, opt.SHAPE, opt.SHAPE, 3), 'input')]
def generator(self, z): def generator(self, z):
""" return an image generated from z""" """ return an image generated from z"""
...@@ -73,8 +79,8 @@ class Model(GANModelDesc): ...@@ -73,8 +79,8 @@ class Model(GANModelDesc):
image_pos = inputs[0] image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([BATCH, Z_DIM], -1, 1, name='z_train') z = tf.random_uniform([opt.BATCH, opt.Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, Z_DIM], name='z') z = tf.placeholder_with_default(z, [None, opt.Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected], with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
...@@ -93,12 +99,21 @@ class Model(GANModelDesc): ...@@ -93,12 +99,21 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3) return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_augmentors():
augs = []
if opt.load_size:
augs.append(imgaug.Resize(opt.load_size))
if opt.crop_size:
augs.append(imgaug.CenterCrop(opt.crop_size))
augs.append(imgaug.Resize(opt.SHAPE))
return augs
def get_data(datadir): def get_data(datadir):
imgs = glob.glob(datadir + '/*.jpg') imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)] ds = AugmentImageComponent(ds, get_augmentors())
ds = AugmentImageComponent(ds, augs) ds = BatchData(ds, opt.BATCH)
ds = BatchData(ds, BATCH)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
return ds return ds
...@@ -106,10 +121,10 @@ def get_data(datadir): ...@@ -106,10 +121,10 @@ def get_data(datadir):
def get_config(): def get_config():
return TrainConfig( return TrainConfig(
model=Model(), model=Model(),
dataflow=get_data(args.data), dataflow=get_data(opt.data),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=100,
) )
...@@ -127,15 +142,23 @@ def sample(model_path): ...@@ -127,15 +142,23 @@ def sample(model_path):
viz = stack_patches(o, nr_row=10, nr_col=10, viz=True) viz = stack_patches(o, nr_row=10, nr_col=10, viz=True)
if __name__ == '__main__': def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='view generated examples') parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='a jpeg directory') parser.add_argument('--data', help='a jpeg directory')
parser.add_argument('--load-size', help='size to load the original images', type=int)
parser.add_argument('--crop-size', help='crop the original images', type=int)
args = parser.parse_args() args = parser.parse_args()
use_global_argument(args)
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
return args
if __name__ == '__main__':
args = get_args()
if args.sample: if args.sample:
sample(args.load) sample(args.load)
else: else:
......
...@@ -21,15 +21,14 @@ from GAN import SeparateGANTrainer, GANModelDesc ...@@ -21,15 +21,14 @@ from GAN import SeparateGANTrainer, GANModelDesc
2. Put list_attr_celeba.txt into that directory as well. 2. Put list_attr_celeba.txt into that directory as well.
3. Start training gender transfer: 3. Start training gender transfer:
./DiscoGAN-CelebA.py --data /path/to/img_align_celeba --style-A Male ./DiscoGAN-CelebA.py --data /path/to/img_align_celeba --style-A Male
4. Visualization on test set to be done. But you can visualize the images in tensorboard now. 4. Visualize the gender conversion images in tensorboard.
With TF1.0.1, cuda 8.0, cudnn 5.1.10, With TF1.0.1, cuda 8.0, cudnn 5.1.10,
the training on 64x64 images of batch 64 runs 5.4 it/s on Tesla M40. the training on 64x64 images of batch 64 runs 5.4 it/s on Tesla M40.
This is 2.4x as fast as the original PyTorch implementation. This is 2.4x as fast as the original PyTorch implementation.
This is surprising to myself, so I'm not sure my comparison is correct.
The cause is probably that in the torch implementation, The cause is probably that in the torch implementation,
a backward() seems to compute gradients for ALL parameters, which is not necessary in GAN. a backward() computes gradients for ALL parameters, which is not necessary in GAN.
""" """
SHAPE = 64 SHAPE = 64
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: WGAN-CelebA.py # File: WGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
...@@ -8,15 +8,13 @@ import argparse ...@@ -8,15 +8,13 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
import tensorflow as tf import tensorflow as tf
from GAN import SeparateGANTrainer from GAN import SeparateGANTrainer
""" """
Wasserstein-GAN. Wasserstein-GAN.
See the docstring in DCGAN-CelebA.py for usage. See the docstring in DCGAN.py for usage.
Actually, just using the clip is enough for WGAN to work (even without BN in generator).
The wasserstein loss is not the key factor.
""" """
# Don't want to mix two examples together, but want to reuse the code. # Don't want to mix two examples together, but want to reuse the code.
...@@ -24,9 +22,13 @@ The wasserstein loss is not the key factor. ...@@ -24,9 +22,13 @@ The wasserstein loss is not the key factor.
import imp import imp
DCGAN = imp.load_source( DCGAN = imp.load_source(
'DCGAN', 'DCGAN',
os.path.join(os.path.dirname(__file__), 'DCGAN-CelebA.py')) os.path.join(os.path.dirname(__file__), 'DCGAN.py'))
G.BATCH = 64
# a hacky way to change loss & optimizer of another script
class Model(DCGAN.Model): class Model(DCGAN.Model):
# def generator(self, z): # def generator(self, z):
# you can override generator to remove BatchNorm, it will still work in WGAN # you can override generator to remove BatchNorm, it will still work in WGAN
...@@ -51,33 +53,20 @@ class Model(DCGAN.Model): ...@@ -51,33 +53,20 @@ class Model(DCGAN.Model):
return optimizer.VariableAssignmentOptimizer(opt, clip) return optimizer.VariableAssignmentOptimizer(opt, clip)
DCGAN.BATCH = 64
DCGAN.Model = Model DCGAN.Model = Model
def get_config():
return TrainConfig(
model=Model(),
# use the same data in the DCGAN example
dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver()],
steps_per_epoch=500,
max_epoch=200,
)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = DCGAN.get_args()
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='a jpeg directory')
args = parser.parse_args()
if args.sample: if args.sample:
DCGAN.sample(args.load) DCGAN.sample(args.load)
else: else:
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = get_config() config = DCGAN.get_config()
config.steps_per_epoch = 500
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment