Commit 9f22aa91 authored by Yuxin Wu

add cyclegan

parent 00811100
@@ -9,7 +9,7 @@ See some [examples](examples) to learn about the framework:
### Vision:
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [Train ResNet on ImageNet / Cifar10 / SVHN](examples/ResNet)
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Visualize Saliency Maps by Guided ReLU](examples/Saliency)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: CycleGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import argparse
import glob
from six.moves import range

import tensorflow as tf

from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, GANModelDesc
"""
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
2. ./CycleGAN.py --data /path/to/datasets/horse2zebra
Training and testing visuliazations will be in tensorboard.
"""

SHAPE = 256
BATCH = 1
TEST_BATCH = 32
NF = 64  # channel size


def INReLU(x, name=None):
    x = InstanceNorm('inorm', x)
    return tf.nn.relu(x, name=name)


def INLReLU(x, name=None):
    x = InstanceNorm('inorm', x)
    return LeakyReLU(x, name=name)


class Model(GANModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
                InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]

    @staticmethod
    def build_res_block(x, name, chan, first=False):
        # residual block with symmetric padding, as in the original CycleGAN architecture
        with tf.variable_scope(name):
            shortcut = x
            return (LinearWrap(x)
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv0', chan, padding='VALID')
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv1', chan, padding='VALID', nl=tf.identity)
                    .InstanceNorm('inorm')()) + shortcut

    @auto_reuse_variable_scope
    def generator(self, img):
        assert img is not None
        # encoder, 9 residual blocks (the 256x256 setting), then decoder
        with argscope([Conv2D, Deconv2D], nl=INReLU, kernel_shape=3):
            l = (LinearWrap(img)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('conv0', NF, kernel_shape=7, padding='VALID')
                 .Conv2D('conv1', NF * 2, stride=2)
                 .Conv2D('conv2', NF * 4, stride=2)())
            for k in range(9):
                l = Model.build_res_block(l, 'res{}'.format(k), NF * 4, first=(k == 0))
            l = (LinearWrap(l)
                 .Deconv2D('deconv0', NF * 2, stride=2)
                 .Deconv2D('deconv1', NF * 1, stride=2)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('convlast', 3, kernel_shape=7, padding='VALID', nl=tf.tanh, use_bias=True)())
        return l

    @auto_reuse_variable_scope
    def discriminator(self, img):
        # PatchGAN-style discriminator: outputs a map of per-patch scores
        with argscope(Conv2D, nl=INLReLU, kernel_shape=4, stride=2):
            l = (LinearWrap(img)
                 .Conv2D('conv0', NF, nl=LeakyReLU)
                 .Conv2D('conv1', NF * 2)
                 .Conv2D('conv2', NF * 4)
                 .Conv2D('conv3', NF * 8, stride=1)
                 .Conv2D('conv4', 1, stride=1, nl=tf.identity, use_bias=True)())
        return l
    def _build_graph(self, inputs):
        A, B = inputs
        # scale images to [-1, 1] and convert NHWC -> NCHW
        A = tf.transpose(A / 128.0 - 1.0, [0, 3, 1, 2])
        B = tf.transpose(B / 128.0 - 1.0, [0, 3, 1, 2])

        def viz3(name, a, b, c):
            # concatenate (input, translated, reconstructed) side by side for the image summary
            im = tf.concat([a, b, c], axis=3)
            im = tf.transpose(im, [0, 2, 3, 1])
            im = (im + 1.0) * 128
            im = tf.clip_by_value(im, 0, 255)
            im = tf.cast(im, tf.uint8, name='viz_' + name)
            tf.summary.image(name, im, max_outputs=50)

        # use the initializers from torch
        with argscope([Conv2D, Deconv2D], use_bias=False,
                      W_init=tf.random_normal_initializer(stddev=0.02)), \
                argscope([Conv2D, Deconv2D, InstanceNorm], data_format='NCHW'), \
                argscope(LeakyReLU, alpha=0.2):
            with tf.variable_scope('gen'):
                with tf.variable_scope('B'):
                    AB = self.generator(A)
                with tf.variable_scope('A'):
                    BA = self.generator(B)
                    ABA = self.generator(AB)
                with tf.variable_scope('B'):
                    BAB = self.generator(BA)

            viz3('A_recon', A, AB, ABA)
            viz3('B_recon', B, BA, BAB)

            with tf.variable_scope('discrim'):
                with tf.variable_scope('A'):
                    A_dis_real = self.discriminator(A)
                    A_dis_fake = self.discriminator(BA)
                with tf.variable_scope('B'):
                    B_dis_real = self.discriminator(B)
                    B_dis_fake = self.discriminator(AB)

        def LSGAN_losses(real, fake):
            # least-squares GAN losses, with the real label smoothed to 0.9
            with tf.name_scope('LSGAN_losses'):
                d_real = tf.reduce_mean(tf.squared_difference(real, 0.9), name='d_real')
                d_fake = tf.reduce_mean(tf.square(fake), name='d_fake')
                d_loss = tf.multiply(d_real + d_fake, 0.5, name='d_loss')

                g_loss = tf.reduce_mean(tf.squared_difference(fake, 0.9), name='g_loss')
                add_moving_summary(g_loss, d_loss)
                return g_loss, d_loss

        with tf.name_scope('LossA'):
            # cycle-consistency (reconstruction) loss
            recon_loss_A = tf.reduce_mean(tf.abs(A - ABA), name='recon_loss')
            # gan loss
            G_loss_A, D_loss_A = LSGAN_losses(A_dis_real, A_dis_fake)
        with tf.name_scope('LossB'):
            recon_loss_B = tf.reduce_mean(tf.abs(B - BAB), name='recon_loss')
            G_loss_B, D_loss_B = LSGAN_losses(B_dis_real, B_dis_fake)

        LAMBDA = 10.0
        self.g_loss = tf.add((G_loss_A + G_loss_B),
                             (recon_loss_A + recon_loss_B) * LAMBDA, name='G_loss_total')
        self.d_loss = tf.add(D_loss_A, D_loss_B, name='D_loss_total')
        self.collect_variables('gen', 'discrim')

        add_moving_summary(recon_loss_A, recon_loss_B, self.g_loss, self.d_loss)

    def _get_optimizer(self):
        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
        return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)


def get_data(datadir, isTrain=True):
    if isTrain:
        augs = [
            imgaug.Resize(int(SHAPE * 1.12)),
            imgaug.RandomCrop(SHAPE),
        ]
    else:
        augs = [imgaug.Resize(SHAPE)]

    def get_image_pairs(dir1, dir2):
        def get_df(dirname):
            files = sorted(glob.glob(os.path.join(dirname, '*.jpg')))
            df = ImageFromFile(files, channel=3, shuffle=isTrain)
            return AugmentImageComponent(df, augs)
        return JoinData([get_df(dir1), get_df(dir2)])

    names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
    df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
    df = BatchData(df, BATCH if isTrain else TEST_BATCH)
    df = PrefetchDataZMQ(df, 2 if isTrain else 1)
    return df


class VisualizeTestSet(Callback):
    def _setup_graph(self):
        self.pred = self.trainer.get_predictor(
            ['inputA', 'inputB'], ['viz_A_recon', 'viz_B_recon'])

    def _before_train(self):
        global args
        self.val_ds = get_data(args.data, isTrain=False)

    def _trigger(self):
        idx = 0
        for iA, iB in self.val_ds.get_data():
            vizA, vizB = self.pred(iA, iB)
            self.trainer.monitors.put_image('testA-{}'.format(idx), vizA)
            self.trainer.monitors.put_image('testB-{}'.format(idx), vizB)
            idx += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data', required=True,
        help='the image directory. should contain trainA/trainB/testA/testB')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()

    logger.auto_set_dir()

    data = get_data(args.data)
    data = PrintData(data)

    config = TrainConfig(
        model=Model(),
        dataflow=data,
        callbacks=[
            ModelSaver(),
            # keep lr at 2e-4 for 100 epochs, then decay it linearly to 0 by epoch 200
            ScheduledHyperParamSetter(
                'learning_rate',
                [(100, 2e-4), (200, 0)], interp='linear'),
            PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
        ],
        max_epoch=195,
        session_init=SaverRestore(args.load) if args.load else None
    )
    GANTrainer(config).train()
@@ -109,6 +109,9 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
    """
    A replacement of GANTrainer (which optimizes d and g one by one), with multi-GPU support.
    """
    def __init__(self, config):
        super(MultiGPUGANTrainer, self).__init__(config)
        self._nr_gpu = config.nr_tower
......
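A minimal usage sketch (hypothetical, not part of this commit), mirroring how CycleGAN.py calls `GANTrainer(config).train()`; the fragment above shows the trainer reading `config.nr_tower`:

    config.nr_tower = 2  # assumption: one tower per available GPU
    MultiGPUGANTrainer(config).train()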
@@ -18,6 +18,9 @@ Reproduce the following GAN-related methods:
+ BEGAN ([BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717))
+ CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))

Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.

## DCGAN.py
@@ -65,6 +68,8 @@ Some BEGAN samples:
![began-sample](demo/BEGAN-CelebA-samples.jpg)

## CycleGAN.py, DiscoGAN-CelebA.py

Reproduce CycleGAN with the original datasets, and DiscoGAN on CelebA. They are pretty much the same idea with different architectures.

![cyclegan-sample](demo/CycleGAN-horse2zebra.jpg)
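
Training follows the script's docstring: download a dataset with the original project's instructions, then run `./CycleGAN.py --data /path/to/datasets/horse2zebra`. Training and testing visualizations are written to tensorboard.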
@@ -12,7 +12,7 @@ Training examples with __reproducible__ and meaningful performance.
+ [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
+ [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [Train ResNet for ImageNet/Cifar10/SVHN](ResNet)
+ [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [Inception-BN with 71% accuracy](Inception/inception-bn.py)
+ [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
......
@@ -578,7 +578,7 @@ class CacheData(ProxyDataFlow):
class PrintData(ProxyDataFlow):
    """
    Behave like an identity mapping, but print the shape and range of the first datapoint once during construction.

    Attributes:
        label (str): a label to identify the data when this debugging tool is used in multiple places.
......
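A usage sketch of PrintData, mirroring how CycleGAN.py above wraps its training dataflow; since it is an identity mapping, it can be inserted anywhere in a dataflow pipeline:

    data = get_data(args.data)
    data = PrintData(data)  # one-time printout of the first datapoint's shape and range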
@@ -29,7 +29,7 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
def _check_tf_version():
    ver = float('.'.join(tf.VERSION.split('.')[:2]))
    assert ver >= 1.1, "TF version {} is too old to run multi GPU training!".format(tf.VERSION)
......