Commit 52a4a0a8 authored by Yuxin Wu

Image to Image

parent d31ba459
@@ -18,11 +18,13 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
 """
 DCGAN on CelebA dataset.
+The original code (dcgan.torch) uses kernel_shape=4, but I found the difference not significant.
 1. Download the 'aligned&cropped' version of CelebA dataset.
 2. Start training:
-./celebA.py --data /path/to/image_align_celeba/
+./DCGAN-CelebA.py --data /path/to/image_align_celeba/
 3. Visualize samples of a trained model:
-./celebA.py --load model.tfmodel --sample
+./DCGAN-CelebA.py --load model.tfmodel --sample
 """
 SHAPE = 64
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: Image2Image.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf
import glob, pickle
import os, sys
import argparse
import cv2
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses
"""
To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain many 512x256 images formed by A and B
To visualize:
./Image2Image.py --data /path/to/test/datadir --mode {AtoB,BtoA} --load pretrained.model
"""
SHAPE = 256
BATCH = 16
IN_CH = 3
OUT_CH = 3
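# weight of the L1 reconstruction term in the generator loss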
LAMBDA = 100
NF = 64  # number of filters

class Model(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
                InputVar(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]

    def generator(self, imgs):
        # imgs: input: 256x256xch
        # U-Net structure, slightly different from the original on the location of relu/lrelu
        with argscope(BatchNorm, use_local_stat=True), \
                argscope(Dropout, is_training=True):
            # always use local stat for BN, and apply dropout even in testing
            with argscope(Conv2D, kernel_shape=4, stride=2,
                          nl=lambda x, name: LeakyReLU(BatchNorm('bn', x), name=name)):
                # encoder: each stride-2 conv halves the spatial size, 256 down to 1
                e1 = Conv2D('conv1', imgs, NF, nl=LeakyReLU)
                e2 = Conv2D('conv2', e1, NF*2)
                e3 = Conv2D('conv3', e2, NF*4)
                e4 = Conv2D('conv4', e3, NF*8)
                e5 = Conv2D('conv5', e4, NF*8)
                e6 = Conv2D('conv6', e5, NF*8)
                e7 = Conv2D('conv7', e6, NF*8)
                e8 = Conv2D('conv8', e7, NF*8, nl=BNReLU)  # 1x1
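            # decoder: stride-2 deconvs upsample back to 256; each U-Net skip
            # connection concatenates the matching encoder feature map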
            with argscope(Deconv2D, nl=BNReLU, kernel_shape=4, stride=2):
                return (LinearWrap(e8)
                        .Deconv2D('deconv1', NF*8)
                        .Dropout()
                        .ConcatWith(3, e7)
                        .Deconv2D('deconv2', NF*8)
                        .Dropout()
                        .ConcatWith(3, e6)
                        .Deconv2D('deconv3', NF*8)
                        .Dropout()
                        .ConcatWith(3, e5)
                        .Deconv2D('deconv4', NF*8)
                        .ConcatWith(3, e4)
                        .Deconv2D('deconv5', NF*4)
                        .ConcatWith(3, e3)
                        .Deconv2D('deconv6', NF*2)
                        .ConcatWith(3, e2)
                        .Deconv2D('deconv7', NF*1)
                        .ConcatWith(3, e1)
                        .Deconv2D('deconv8', OUT_CH, nl=tf.tanh)())

    def discriminator(self, inputs, outputs):
        """ return a (b, h, w, 1) map of logits (a PatchGAN-style discriminator) """
        l = tf.concat(3, [inputs, outputs])
        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2):
            l = (LinearWrap(l)
                 .Conv2D('conv0', NF, nl=LeakyReLU)
                 .Conv2D('conv1', NF*2)
                 .BatchNorm('bn1').LeakyReLU()
                 .Conv2D('conv2', NF*4)
                 .BatchNorm('bn2').LeakyReLU()
                 .Conv2D('conv3', NF*8, stride=1)  # valid?
                 .BatchNorm('bn3').LeakyReLU()
                 .Conv2D('convlast', 1, stride=1)())
        return l

    def _build_graph(self, input_vars):
        input, output = input_vars
        input, output = input / 128.0 - 1, output / 128.0 - 1  # scale to [-1, 1]

        with argscope([Conv2D, Deconv2D],
                      W_init=tf.truncated_normal_initializer(stddev=0.02)), \
                argscope(LeakyReLU, alpha=0.2):
            with tf.variable_scope('gen'):
                fake_output = self.generator(input)
            with tf.variable_scope('discrim'):
                real_pred = self.discriminator(input, output)
            with tf.variable_scope('discrim', reuse=True):
                # reuse the same discriminator weights on the generated pair
                fake_pred = self.discriminator(input, fake_output)

        self.g_loss, self.d_loss = build_GAN_losses(real_pred, fake_pred)
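        # pix2pix objective: the generator also minimizes a LAMBDA-weighted L1
        # distance between its output and the ground-truth image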
        errL1 = tf.reduce_mean(tf.abs(fake_output - output), name='L1_loss')
        self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss')
        add_moving_summary(errL1, self.g_loss)

        # visualization
        if IN_CH == 1:
            input = tf.image.grayscale_to_rgb(input)
        if OUT_CH == 1:
            output = tf.image.grayscale_to_rgb(output)
            fake_output = tf.image.grayscale_to_rgb(fake_output)
        viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0
        viz = tf.cast(viz, tf.uint8, name='viz')
        tf.image_summary('gen', viz, max_images=max(30, BATCH))

        # separate variable lists for the alternating updates in GANTrainer
        all_vars = tf.trainable_variables()
        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]

def split_input(img):
    """
    img: a 512x256x3 image
    :return: [input, output]
    """
    input, output = img[:, :256, :], img[:, 256:, :]
    if args.mode == 'BtoA':
        input, output = output, input
    if IN_CH == 1:
        input = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)
    if OUT_CH == 1:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)
    return [input, output]

def get_data():
    datadir = args.data
    # assume each image is 512x256 split to left and right
    imgs = glob.glob(os.path.join(datadir, '*.jpg'))
    ds = ImageFromFile(imgs, channel=3, shuffle=True)
    ds = MapData(ds, lambda dp: split_input(dp[0]))
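    # jitter augmentation as in pix2pix: upscale to 286, then random-crop back to 256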
    augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
    ds = AugmentImageComponents(ds, augs, (0, 1))
    ds = BatchData(ds, BATCH)
    ds = PrefetchDataZMQ(ds, 1)
    return ds

def get_config():
    logger.auto_set_dir()
    dataset = get_data()
    lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
    return TrainConfig(
        dataset=dataset,
        # beta1=0.5 following DCGAN/pix2pix
        optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(), ModelSaver(),
            ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
        ]),
        session_config=get_default_sess_config(0.8),
        model=Model(),
        step_per_epoch=300,
        max_epoch=300,
    )

def sample(datadir, model_path):
    pred = PredictConfig(
        session_init=get_model_loader(model_path),
        model=Model(),
        input_names=['input', 'output'],
        output_names=['viz'])
    imgs = glob.glob(os.path.join(datadir, '*.jpg'))
    ds = ImageFromFile(imgs, channel=3, shuffle=True)
    ds = BatchData(MapData(ds, lambda dp: split_input(dp[0])), 16)
    pred = SimpleDatasetPredictor(pred, ds)
    for o in pred.get_result():
        o = o[:, :, :, ::-1]  # RGB -> BGR for the OpenCV-based viewer
        viz = next(build_patch_list(o, nr_row=4, nr_col=4, viz=True))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--sample', action='store_true', help='run sampling')
    parser.add_argument('--data', help='a directory of images')
    parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.sample:
        sample(args.data, args.load)
    else:
        assert args.data
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        GANTrainer(config, g_vs_d=1).train()
-# Deep Convolutional Generative Adversarial Networks
+# Generative Adversarial Networks
+## DCGAN-CelebA.py
 Reproduce DCGAN following the setup in [dcgan.torch](https://github.com/soumith/dcgan.torch).
@@ -13,3 +15,8 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv
 ![vec](demo/CelebA-vec.jpg)
 See the docstring in the script for usage.
+## Image2Image.py
+Reproduce [Image-to-image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf),
+following the setup in [pix2pix](https://github.com/phillipi/pix2pix).
@@ -14,17 +14,14 @@ __all__ = ['DumpParamAsImage']
 class DumpParamAsImage(Callback):
     """
-    Dump a variable to image(s) after every epoch.
+    Dump a variable to image(s) after every epoch to logger.LOG_DIR.
     """
     def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
         """
         :param var_name: the name of the variable.
         :param prefix: the filename prefix for saved images. Default is the op name.
         :param map_func: map the value of the variable to an image or list of
             images of shape [h, w] or [h, w, c]. If None, will use identity
         :param scale: a multiplier on pixel values, applied after map_func. default to 255
         :param clip: whether to clip the result to [0, 255]
         """
@@ -22,6 +22,7 @@ class ImageFromFile(RNGDataFlow):
         assert len(files)
         self.files = files
         self.channel = int(channel)
+        self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
         self.resize = resize
         self.shuffle = shuffle
@@ -32,8 +33,7 @@ class ImageFromFile(RNGDataFlow):
         if self.shuffle:
             self.rng.shuffle(self.files)
         for f in self.files:
-            im = cv2.imread(
-                f, cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR)
+            im = cv2.imread(f, self.imread_mode)
             if self.channel == 3:
                 im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             if self.resize is not None:
@@ -66,6 +66,6 @@ def LeakyReLU(x, alpha, name=None):
 @layer_register(log_shape=False, use_scope=False)
 def BNReLU(x, name=None):
-    x = BatchNorm('bn', x, use_local_stat=None)
+    x = BatchNorm('bn', x)
     x = tf.nn.relu(x, name=name)
     return x
@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name=None):
     return tf.add_n(costs, name=name)

-@layer_register(log_shape=False)
+@layer_register(log_shape=False, use_scope=False)
 def Dropout(x, keep_prob=0.5, is_training=None):
     """
     :param is_training: if None, will use the current context by default.
@@ -55,6 +55,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
     def get_result(self):
         """ A generator to produce prediction for each data"""
+        self.dataset.reset_state()
         try:
             sz = self.dataset.size()
         except NotImplementedError:
@@ -77,7 +77,7 @@ def build_patch_list(patch_list,
                      viz=False, lclick_cb=None):
     """
     Generate patches.
-    :param patch_list: bhw or bhwc
+    :param patch_list: bhw or bhwc images in [0,255]
     :param border: defaults to 0.1 * max(image_width, image_height)
     :param nr_row, nr_col: rows and cols of the grid
     :param max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols