Commit 8f056dc1 authored by Yuxin Wu's avatar Yuxin Wu

use global namespace between WGAN and DCGAN so that arguments are easier to share

parent 5beab907
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: DCGAN-CelebA.py
# File: DCGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import glob
......@@ -11,6 +11,8 @@ from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.globvars import globalns as opt
from tensorpack.utils.globvars import use_global_argument
import tensorflow as tf
from GAN import GANTrainer, RandomZData, GANModelDesc
......@@ -18,25 +20,29 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
"""
1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
2. Start training:
./DCGAN-CelebA.py --data /path/to/img_align_celeba/
3. Visualize samples of a trained model:
./DCGAN-CelebA.py --data /path/to/img_align_celeba/ --crop-size 140
Generated samples will be available through tensorboard
3. Visualize samples with an existing model:
./DCGAN-CelebA.py --load path/to/model --sample
You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`.
`--data`). But you may need to change the preprocessing.
A pretrained model on CelebA is at https://drive.google.com/open?id=0B9IPQTvr2BBkLUF2M0RXU1NYSkE
"""
SHAPE = 64
BATCH = 128
Z_DIM = 100
# global vars
opt.SHAPE = 64
opt.BATCH = 128
opt.Z_DIM = 100
class Model(GANModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
return [InputDesc(tf.float32, (None, opt.SHAPE, opt.SHAPE, 3), 'input')]
def generator(self, z):
""" return an image generated from z"""
......@@ -73,8 +79,8 @@ class Model(GANModelDesc):
image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([BATCH, Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, Z_DIM], name='z')
z = tf.random_uniform([opt.BATCH, opt.Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, opt.Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)):
......@@ -93,12 +99,21 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_augmentors():
augs = []
if opt.load_size:
augs.append(imgaug.Resize(opt.load_size))
if opt.crop_size:
augs.append(imgaug.CenterCrop(opt.crop_size))
augs.append(imgaug.Resize(opt.SHAPE))
return augs
def get_data(datadir):
imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)]
ds = AugmentImageComponent(ds, augs)
ds = BatchData(ds, BATCH)
ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, opt.BATCH)
ds = PrefetchDataZMQ(ds, 1)
return ds
......@@ -106,10 +121,10 @@ def get_data(datadir):
def get_config():
return TrainConfig(
model=Model(),
dataflow=get_data(args.data),
dataflow=get_data(opt.data),
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
max_epoch=100,
)
......@@ -127,15 +142,23 @@ def sample(model_path):
viz = stack_patches(o, nr_row=10, nr_col=10, viz=True)
if __name__ == '__main__':
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='a jpeg directory')
parser.add_argument('--load-size', help='size to load the original images', type=int)
parser.add_argument('--crop-size', help='crop the original images', type=int)
args = parser.parse_args()
use_global_argument(args)
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
return args
if __name__ == '__main__':
args = get_args()
if args.sample:
sample(args.load)
else:
......
......@@ -21,15 +21,14 @@ from GAN import SeparateGANTrainer, GANModelDesc
2. Put list_attr_celeba.txt into that directory as well.
3. Start training gender transfer:
./DiscoGAN-CelebA.py --data /path/to/img_align_celeba --style-A Male
4. Visualization on test set to be done. But you can visualize the images in tensorboard now.
4. Visualize the gender conversion images in tensorboard.
With TF1.0.1, cuda 8.0, cudnn 5.1.10,
the training on 64x64 images of batch 64 runs 5.4 it/s on Tesla M40.
This is 2.4x as fast as the original PyTorch implementation.
This is surprising to myself, so I'm not sure my comparison is correct.
The cause is probably that in the torch implementation,
a backward() seems to compute gradients for ALL parameters, which is not necessary in GAN.
a backward() computes gradients for ALL parameters, which is not necessary in GAN.
"""
SHAPE = 64
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: WGAN-CelebA.py
# File: WGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
......@@ -8,15 +8,13 @@ import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G
import tensorflow as tf
from GAN import SeparateGANTrainer
"""
Wasserstein-GAN.
See the docstring in DCGAN-CelebA.py for usage.
Actually, just using the clip is enough for WGAN to work (even without BN in generator).
The wasserstein loss is not the key factor.
See the docstring in DCGAN.py for usage.
"""
# Don't want to mix two examples together, but want to reuse the code.
......@@ -24,9 +22,13 @@ The wasserstein loss is not the key factor.
import imp
DCGAN = imp.load_source(
'DCGAN',
os.path.join(os.path.dirname(__file__), 'DCGAN-CelebA.py'))
os.path.join(os.path.dirname(__file__), 'DCGAN.py'))
G.BATCH = 64
# a hacky way to change loss & optimizer of another script
class Model(DCGAN.Model):
# def generator(self, z):
# you can override generator to remove BatchNorm, it will still work in WGAN
......@@ -51,33 +53,20 @@ class Model(DCGAN.Model):
return optimizer.VariableAssignmentOptimizer(opt, clip)
DCGAN.BATCH = 64
DCGAN.Model = Model
def get_config():
return TrainConfig(
model=Model(),
# use the same data in the DCGAN example
dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver()],
steps_per_epoch=500,
max_epoch=200,
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='a jpeg directory')
args = parser.parse_args()
args = DCGAN.get_args()
if args.sample:
DCGAN.sample(args.load)
else:
assert args.data
logger.auto_set_dir()
config = get_config()
config = DCGAN.get_config()
config.steps_per_epoch = 500
if args.load:
config.session_init = SaverRestore(args.load)
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment