Commit 13c96b94 authored by Yuxin Wu's avatar Yuxin Wu

add a WGAN example

parent 00fdb263
......@@ -10,6 +10,8 @@ Reproduce the following GAN-related methods:
+ Conditional GAN
+ [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
Please see the __docstring__ in each script for detailed usage.
## DCGAN-CelebA.py
......@@ -51,3 +53,7 @@ It then maximizes mutual information between these latent variables and the imag
## ConditionalGAN-mnist.py
Train a simple GAN on mnist, conditioned on the class labels.
## WGAN-CelebA.py
Reproduce WGAN by making a few small modifications to DCGAN-CelebA.py.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: WGAN-CelebA.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import os
import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from GAN import GANTrainer
"""
Wasserstein-GAN.
See the docstring in DCGAN-CelebA.py for usage.
Actually, just using the clip is enough for WGAN to work (even without BN in generator).
The wasserstein loss is not the key factor.
"""
# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN-CelebA, and change the batch size & model
# The file name 'DCGAN-CelebA.py' contains a hyphen, so it cannot be pulled in
# with a plain `import` statement; it is loaded by path instead, resolved
# relative to the directory of this script.
# NOTE(review): `imp` is deprecated in Python 3 — consider importlib if this
# example needs to run on recent interpreters.
import imp
DCGAN = imp.load_source(
'DCGAN',
os.path.join(os.path.dirname(__file__), 'DCGAN-CelebA.py'))
class Model(DCGAN.Model):
    """The DCGAN example model with its losses and optimizer replaced by the WGAN ones."""

    # def generator(self, z):
    # you can override generator to remove BatchNorm, it will still work in WGAN

    def build_losses(self, vecpos, vecneg):
        """Define the Wasserstein-GAN losses and register them as moving summaries.

        Sets ``self.d_loss`` and ``self.g_loss``.

        Args:
            vecpos: presumably the discriminator output on real samples
                (TODO confirm against DCGAN.Model).
            vecneg: presumably the discriminator output on generated samples
                (TODO confirm against DCGAN.Model).
        """
        # the Wasserstein-GAN losses
        self.d_loss = tf.reduce_mean(vecneg - vecpos, name='d_loss')
        # Name the final (negated) tensor, not the inner mean: with
        # ``-tf.reduce_mean(..., name='g_loss')`` the tensor stored in
        # self.g_loss is an auto-named negation op, so the summary would not
        # be reported under 'g_loss'.
        self.g_loss = tf.negative(tf.reduce_mean(vecneg), name='g_loss')
        add_moving_summary(self.d_loss, self.g_loss)

    def _get_optimizer(self):
        """Return an RMSProp optimizer driven by a 'learning_rate' scalar variable (initial value 1e-4)."""
        lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
        return tf.train.RMSPropOptimizer(lr)
# Patch the loaded DCGAN module in place so that its own helpers pick up the
# WGAN settings: a batch size of 64 ...
DCGAN.BATCH = 64
# ... and the WGAN model class defined in this file. Presumably this is what
# makes DCGAN.sample() build the WGAN model — verify against DCGAN-CelebA.py.
DCGAN.Model = Model
def get_config():
    """Assemble the TrainConfig for WGAN training.

    Reads the image directory from the module-level ``args`` (set when the
    file is run as a script), so it must be called after argument parsing.

    Returns:
        TrainConfig: 200 epochs of 300 steps each, with a ModelSaver callback.
    """
    wgan_model = Model()
    # use the same data in the DCGAN example
    celeba_flow = DCGAN.get_data(args.data)
    return TrainConfig(
        model=wgan_model,
        dataflow=celeba_flow,
        callbacks=[ModelSaver()],
        session_config=get_default_sess_config(0.5),
        steps_per_epoch=300,
        max_epoch=200,
    )
class WGANTrainer(FeedfreeTrainerBase):
    """Trainer that runs five discriminator updates for every generator update.

    The discriminator optimizer is wrapped so that ``tf.clip_by_value`` with
    bounds [-0.01, 0.01] is applied to its variables.
    """

    def __init__(self, config):
        """Feed the trainer from ``config.dataflow`` through a QueueInput."""
        self._input_method = QueueInput(config.dataflow)
        super(WGANTrainer, self).__init__(config)

    def _setup(self):
        """Build the training tower and create the two minimize ops."""
        super(WGANTrainer, self)._setup()
        self.build_train_tower()

        # add clipping to D optimizer
        def clip(var):
            # Log which variable gets clipped, then clamp it to [-0.01, 0.01].
            logger.info("Clip {}".format(var.op.name))
            return tf.clip_by_value(var, -0.01, 0.01)

        gen_opt = self.model.get_optimizer()
        # Presumably applies `clip` to each variable after the wrapped
        # optimizer's update — confirm in VariableAssignmentOptimizer.
        disc_opt = optimizer.VariableAssignmentOptimizer(gen_opt, clip)
        self.d_min = disc_opt.minimize(
            self.model.d_loss,
            var_list=self.model.d_vars,
            name='d_min')
        self.g_min = gen_opt.minimize(
            self.model.g_loss,
            var_list=self.model.g_vars,
            name='g_op')

    def run_step(self):
        """Run five discriminator steps, then one generator step.

        Returns:
            list: results of the extra fetches evaluated with the generator step.
        """
        for _ in range(5):
            self.sess.run(self.d_min)
        fetches = [self.g_min] + self.get_extra_fetches()
        results = self.sess.run(fetches)
        # Drop the first entry, which is the result of the train op itself.
        return results[1:]
if __name__ == '__main__':
    # Script entry point: either view samples from a trained model (--sample)
    # or train a WGAN on a directory of jpeg images (--data).
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load model')
    parser.add_argument('--sample', action='store_true', help='view generated examples')
    parser.add_argument('--data', help='a jpeg directory')
    # `args` must remain a module-level name: get_config() reads args.data.
    args = parser.parse_args()
    if args.sample:
        DCGAN.sample(args.load)
    else:
        # Validate user input explicitly; an `assert` would be stripped under
        # `python -O` and gives no usage message.
        if not args.data:
            parser.error('--data is required for training')
        logger.auto_set_dir()
        config = get_config()
        if args.load:
            # Resume from (or initialize with) an existing checkpoint.
            config.session_init = SaverRestore(args.load)
        WGANTrainer(config).train()
......@@ -17,13 +17,10 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader']
# TODO they initialize_all at the beginning by default.
@six.add_metaclass(ABCMeta)
class SessionInit(object):
""" Base class for utilities to initialize a session. """
def init(self, sess):
"""
Initialize a session
......@@ -40,7 +37,6 @@ class SessionInit(object):
class JustCurrentSession(SessionInit):
""" This is a no-op placeholder"""
def _init(self, sess):
pass
......@@ -49,7 +45,6 @@ class NewSession(SessionInit):
"""
Initialize global variables by their initializer.
"""
def _init(self, sess):
sess.run(tf.global_variables_initializer())
......@@ -62,7 +57,7 @@ class SaverRestore(SessionInit):
def __init__(self, model_path, prefix=None):
"""
Args:
model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
model_path (str): path to the model (model-xxxx) or a ``checkpoint`` file.
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
"""
model_path = get_checkpoint_path(model_path)
......@@ -150,7 +145,7 @@ class ChainInit(SessionInit):
def __init__(self, sess_inits, new_session=True):
"""
Args:
sess_inits (list): list of :class:`SessionInit` instances.
sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
new_session (bool): add a ``NewSession()`` at the beginning, if
not there.
"""
......
......@@ -45,7 +45,7 @@ def change_env(name, val):
def get_rng(obj=None):
"""
Get a good RNG.
Get a good RNG seeded with time, pid and the object.
Args:
obj: some object to use to generate random seed.
......
......@@ -17,7 +17,7 @@ except ImportError:
pass
__all__ = ['pyplot2img', 'pyplot_viz', 'interactive_imshow',
__all__ = ['pyplot2img', 'interactive_imshow',
'stack_patches', 'gen_stack_patches',
'dump_dataflow_images', 'intensity_to_rgb']
......@@ -34,31 +34,13 @@ def pyplot2img(plt):
return im
def pyplot_viz(img, shape=None):
    """Render an image through pyplot and return the rendered figure.

    For example, a grayscale input will automatically be drawn with a colormap.

    Args:
        img: the image handed to ``plt.imshow``.
        shape: if not None, the rendered result is passed through
            ``cv2.resize`` with this size.

    Returns:
        np.ndarray: an image.

    Note:
        This is quite slow, and the returned image will have a border.
    """
    plt.clf()
    # Axes covering the whole figure area.
    plt.axes([0, 0, 1, 1])
    plt.imshow(img)
    rendered = pyplot2img(plt)
    if shape is None:
        return rendered
    return cv2.resize(rendered, shape)
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
Args:
img (np.ndarray): an image (expect BGR) to show.
lclick_cb: a callback func(img, x, y) for left click event.
lclick_cb, rclick_cb: a callback ``func(img, x, y)`` for left/right click event.
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
specify a callback func(img) for keypress.
specify a callback ``func(img)`` for keypress.
Some existing keypress event handler:
......@@ -187,7 +169,7 @@ def stack_patches(
nr_row(int), nr_col(int): rows and cols of the grid.
``nr_col * nr_row`` must be equal to ``len(patch_list)``.
border(int): border length between images.
Defaults to ``0.1 * min(image_w, image_h)``.
Defaults to ``0.1 * min(patch_width, patch_height)``.
pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
This option allows stacking patches of different shapes together.
bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment