Commit 9d6197aa authored by Patrick Wieschollek, committed by Yuxin Wu

GAN refactor (#107)

* initial public refactor

* add uniform

* simplify code

* split uniform-factor

* fix prior

* working version

* more documentation

* clean up done for InfoGAN

* updated other examples to match changes in GAN.py

* final edits from my side

* docs changes in base class

* more docs and interface name

* some changes in name scope and docs

* use property for _dim

* visualization & gradient scale according to the paper

* update the demo img
parent 677af274
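
At a high level, this commit replaces the free function `build_GAN_losses` with a `GANModelDesc` base class. Below is a minimal, hedged sketch of the new interface: `GANModelDesc`, `build_losses`, `collect_variables` and the `gen/` / `discrim/` scope convention come from the diff itself, while the `ToyGAN` name and the toy layers are illustrative placeholders only.

```python
# Hedged sketch of the interface introduced by this commit; only the
# GANModelDesc hooks and the variable-scope names are taken from the diff.
import tensorflow as tf
from tensorpack import FullyConnected
from GAN import GANModelDesc

class ToyGAN(GANModelDesc):
    """Toy model; _get_input_vars() and the real architectures are omitted."""

    def generator(self, z):
        return FullyConnected('fc', z, 28 * 28, nl=tf.tanh)

    def discriminator(self, x):
        x = tf.reshape(x, [-1, 28 * 28])
        return FullyConnected('fc', x, 1, nl=tf.identity)   # (b, 1) logits

    def _build_graph(self, input_vars):
        image = tf.reshape(input_vars[0], [-1, 28 * 28])
        z = tf.random_uniform(tf.stack([tf.shape(image)[0], 100]), -1, 1)
        with tf.variable_scope('gen'):               # prefix expected by collect_variables()
            image_gen = self.generator(z)
        with tf.variable_scope('discrim'):
            logits_real = self.discriminator(image)
        with tf.variable_scope('discrim', reuse=True):
            logits_fake = self.discriminator(image_gen)
        # these two calls replace the old build_GAN_losses() helper:
        self.build_losses(logits_real, logits_fake)   # sets self.g_loss / self.d_loss
        self.collect_variables()                      # sets self.g_vars / self.d_vars
```
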
......@@ -17,7 +17,7 @@ from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, build_GAN_losses
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
DCGAN on CelebA dataset.
......@@ -35,7 +35,7 @@ CFG.BATCH = 128
CFG.Z_DIM = 100
class Model(ModelDesc):
class Model(GANModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input')]
......@@ -87,10 +87,8 @@ class Model(ModelDesc):
with tf.variable_scope('discrim', reuse=True):
vecneg = self.discriminator(image_gen)
self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
self.build_losses(vecpos, vecneg)
self.collect_variables()
def get_data():
......
......@@ -7,11 +7,68 @@ import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, TowerContext,
get_global_step_var, QueueInput)
get_global_step_var, QueueInput, ModelDesc)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient
from tensorpack.dataflow import DataFlow
class GANModelDesc(ModelDesc):
def collect_variables(self):
"""Extract variables by prefix
"""
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D)
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs".
Args:
logits_real (tf.Tensor): discrim logits from real samples
logits_fake (tf.Tensor): discrim logits from fake samples produced by generator
Returns:
tf.Tensor: Description
"""
with tf.name_scope("GAN_loss"):
score_real = tf.sigmoid(logits_real)
score_fake = tf.sigmoid(logits_fake)
tf.summary.histogram('score-real', score_real)
tf.summary.histogram('score-fake', score_fake)
with tf.name_scope("discrim"):
d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_real, labels=tf.zeros_like(logits_real)), name='loss_real')
d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss_fake')
d_pos_acc = tf.reduce_mean(tf.cast(score_real < 0.5, tf.float32), name='accuracy_real')
d_neg_acc = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy_fake')
self.d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')
self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss')
with tf.name_scope("gen"):
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss')
self.g_accuracy = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy')
add_moving_summary(self.g_loss, self.d_loss, self.d_accuracy, self.g_accuracy)
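
For reference, `tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=z)` computes the numerically stable form `max(x, 0) - x*z + log(1 + exp(-|x|))`. A small numpy sketch (illustrative only) of why the swapped labels drive the discriminator to score real images near 0 and fakes near 1:

```python
import numpy as np

def sigmoid_ce(x, z):
    # same formula TF documents for sigmoid_cross_entropy_with_logits
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

# real samples carry label 0 here (d_loss_pos above), so a high "real" logit is penalized:
print(sigmoid_ce(4.0, 0.0))    # ~4.02 -> large loss, pushes D(real) towards 0
print(sigmoid_ce(-4.0, 0.0))   # ~0.02 -> small loss once real images score low
# fake samples carry label 1 (d_loss_neg above), so D is pushed to score them near 1.
```
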
def get_gradient_processor_g(self):
return [CheckGradient()]
def get_gradient_processor_d(self):
return [CheckGradient()]
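
These two hooks are consumed by `GANTrainer` below, and a subclass can override them to post-process gradients before they are applied. InfoGAN-mnist.py further down in this commit does exactly this to scale all generator gradients; a short sketch (the `ScaledGAN` name is hypothetical):

```python
# Sketch: override the generator's gradient processors, as InfoGAN-mnist.py does.
from tensorpack.tfutils.gradproc import CheckGradient, ScaleGradient
from GAN import GANModelDesc

class ScaledGAN(GANModelDesc):
    def get_gradient_processor_g(self):
        # check that gradients are finite, then multiply every generator
        # gradient (all variables match '.*') by 5
        return [CheckGradient(), ScaleGradient(('.*', 5), log=False)]
```
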
class GANTrainer(FeedfreeTrainerBase):
def __init__(self, config):
self._input_method = QueueInput(config.dataflow)
......@@ -22,11 +79,18 @@ class GANTrainer(FeedfreeTrainerBase):
with TowerContext(''):
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
self.g_min = self.config.optimizer.minimize(self.model.g_loss,
var_list=self.model.g_vars, name='g_op')
grads = self.config.optimizer.compute_gradients(
self.model.g_loss, var_list=self.model.g_vars)
grads = apply_grad_processors(
grads, self.model.get_gradient_processor_g())
self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
with tf.control_dependencies([self.g_min]):
self.d_min = self.config.optimizer.minimize(self.model.d_loss,
var_list=self.model.d_vars, name='d_op')
grads = self.config.optimizer.compute_gradients(
self.model.d_loss, var_list=self.model.d_vars)
grads = apply_grad_processors(
grads, self.model.get_gradient_processor_d())
self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
self.summary_op = summary_moving_average()
self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
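
With this change, each run of `train_op` first applies the processed generator gradients (`g_min`), then, via the control dependency, the processed discriminator gradients (`d_min`), and finally increments the global step. A hedged usage sketch, assuming the usual tensorpack pattern where `get_config()` returns a `TrainConfig` whose `model` is a `GANModelDesc` subclass (as in the example scripts):

```python
# Hedged sketch: driving the refactored trainer; get_config() is assumed to exist
# in the calling script and to supply dataflow, optimizer, callbacks and model.
from GAN import GANTrainer

config = get_config()
GANTrainer(config).train()   # each step: G update -> D update -> global step increment
```
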
......@@ -44,31 +108,3 @@ class RandomZData(DataFlow):
def get_data(self):
while True:
yield [np.random.uniform(-1, 1, size=self.shape)]
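
`RandomZData` is a small DataFlow that yields uniform noise batches forever, e.g. to feed the generator at inference time. A hedged sketch of its use (the constructor is not shown in this hunk; it is assumed to simply store the requested shape as `self.shape`, which is what `get_data()` implies):

```python
import numpy as np
# Assumption: RandomZData((BATCH, Z_DIM)) stores the shape used by get_data().
df = RandomZData((128, 100))
z = next(df.get_data())[0]   # np.ndarray of shape (128, 100), values in [-1, 1)
```
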
def build_GAN_losses(vecpos, vecneg):
"""
:param vecpos, vecneg: output of the discriminator (logits) for real
and fake images.
:return: (loss of G, loss of D)
"""
sigmpos = tf.sigmoid(vecpos)
sigmneg = tf.sigmoid(vecneg)
tf.summary.histogram('sigmoid-pos', sigmpos)
tf.summary.histogram('sigmoid-neg', sigmneg)
d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=vecpos, labels=tf.ones_like(vecpos)), name='d_CE_loss_pos')
d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=vecneg, labels=tf.zeros_like(vecneg)), name='d_CE_loss_neg')
d_pos_acc = tf.reduce_mean(tf.cast(sigmpos > 0.5, tf.float32), name='pos_acc')
d_neg_acc = tf.reduce_mean(tf.cast(sigmneg < 0.5, tf.float32), name='neg_acc')
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=vecneg, labels=tf.ones_like(vecneg)), name='g_CE_loss')
d_loss = tf.add(d_loss_pos, d_loss_neg, name='d_CE_loss')
add_moving_summary(d_loss_pos, d_loss_neg,
g_loss, d_loss,
d_pos_acc, d_neg_acc)
return g_loss, d_loss
......@@ -16,7 +16,7 @@ from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, build_GAN_losses
from GAN import GANTrainer, GANModelDesc
"""
To train:
......@@ -42,7 +42,7 @@ LAMBDA = 100
NF = 64 # number of filters
class Model(ModelDesc):
class Model(GANModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
......@@ -114,7 +114,7 @@ class Model(ModelDesc):
with tf.variable_scope('discrim', reuse=True):
fake_pred = self.discriminator(input, fake_output)
self.g_loss, self.d_loss = build_GAN_losses(real_pred, fake_pred)
self.build_losses(real_pred, fake_pred)
errL1 = tf.reduce_mean(tf.abs(fake_output - output), name='L1_loss')
self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss')
add_moving_summary(errL1, self.g_loss)
......@@ -129,9 +129,7 @@ class Model(ModelDesc):
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
self.collect_variables()
def split_input(img):
......
......@@ -12,13 +12,16 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.distributions import *
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, build_GAN_losses
from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
from GAN import GANTrainer, GANModelDesc
BATCH = 128
NOISE_DIM = 62
class Model(ModelDesc):
class Model(GANModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, 28, 28), 'input')]
......@@ -29,13 +32,12 @@ class Model(ModelDesc):
l = tf.reshape(l, [-1, 7, 7, 128])
l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU)
l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity)
l = tf.nn.tanh(l, name='gen')
l = tf.tanh(l, name='gen')
return l
def discriminator(self, imgs):
""" return a (b, 1) logits"""
with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
argscope(LeakyReLU, alpha=0.2):
argscope(LeakyReLU, alpha=0.1):
l = (LinearWrap(imgs)
.Conv2D('conv0', 64)
.LeakyReLU()
......@@ -48,49 +50,57 @@ class Model(ModelDesc):
encoder = (LinearWrap(l)
.FullyConnected('fce1', 128, nl=tf.identity)
.BatchNorm('bne').LeakyReLU()
.FullyConnected('fce-out', 10, nl=tf.identity)())
.FullyConnected('fce-out', self.factors.param_dim, nl=tf.identity)())
return logits, encoder
def _build_graph(self, input_vars):
image_pos = input_vars[0]
image_pos = tf.expand_dims(image_pos * 2.0 - 1, -1)
real_sample = input_vars[0]
real_sample = tf.expand_dims(real_sample * 2.0 - 1, -1)
prior_prob = tf.constant([0.1] * 10, name='prior_prob')
# assume first 10 is categorical
ids = tf.multinomial(tf.zeros([BATCH, 10]), num_samples=1)[:, 0]
zc = tf.one_hot(ids, 10, name='zc_train')
zc = tf.placeholder_with_default(zc, [None, 10], name='zc')
# latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
GaussianDistributionUniformPrior("uni_a", 1),
GaussianDistributionUniformPrior("uni_b", 1),
NoiseDistribution("noise", NOISE_DIM)])
z = tf.random_uniform(tf.stack([tf.shape(zc)[0], 90]), -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, 90], name='z')
z = tf.concat_v2([zc, z], 1, name='fullz')
z = self.factors.sample_prior(BATCH, name='zc')
with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)):
with tf.variable_scope('gen'):
image_gen = self.generator(z)
tf.summary.image('gen', image_gen, max_outputs=30)
fake_sample = self.generator(z)
fake_sample_viz = tf.cast((fake_sample + 1) * 128.0, tf.uint8, name='viz')
tf.summary.image('gen', fake_sample_viz, max_outputs=30)
# TODO investigate how bn stats should be updated across two discrim
with tf.variable_scope('discrim'):
vecpos, _ = self.discriminator(image_pos)
real_pred, _ = self.discriminator(real_sample)
with tf.variable_scope('discrim', reuse=True):
vecneg, dist_param = self.discriminator(image_gen)
logprob = tf.nn.log_softmax(dist_param) # log prob of each category
fake_pred, dist_param = self.discriminator(fake_sample)
# post-process all dist_params from discriminator
encoder_activation = self.factors.encoder_activation(dist_param)
with tf.name_scope("mutual_information"):
MIs = self.factors.mutual_information(z, encoder_activation)
mi = tf.add_n(MIs, name="total")
summary.add_moving_summary(MIs + [mi])
# default GAN objective
self.build_losses(real_pred, fake_pred)
# Q(c|x) = Q(zc | image_gen)
log_qc_given_x = tf.reduce_sum(logprob * zc, 1, name='logQc_x') # bx1
log_qc = tf.reduce_sum(prior_prob * zc, 1, name='logQc')
Elog_qc_given_x = tf.reduce_mean(log_qc_given_x, name='ElogQc_x')
Hc = tf.reduce_mean(-log_qc, name='Hc')
MIloss = tf.multiply(Hc + Elog_qc_given_x, -1.0, name='neg_MI')
# subtract mutual information for the latent factors (we want to maximize it)
self.g_loss = tf.subtract(self.g_loss, mi, name='total_g_loss')
self.d_loss = tf.subtract(self.d_loss, mi, name='total_d_loss')
self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
self.g_loss = tf.add(self.g_loss, MIloss, name='total_g_loss')
self.d_loss = tf.add(self.d_loss, MIloss, name='total_d_loss')
summary.add_moving_summary(MIloss, self.g_loss, self.d_loss, Hc, Elog_qc_given_x)
summary.add_moving_summary(self.g_loss, self.d_loss)
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
# separate the variables used by the generator and discriminator updates
self.collect_variables()
def get_gradient_processor_g(self):
return [CheckGradient(), ScaleGradient(('.*', 5), log=False)]
def get_data():
......@@ -105,7 +115,7 @@ def get_config():
lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
return TrainConfig(
dataflow=dataset,
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6),
callbacks=Callbacks([
StatPrinter(), ModelSaver(),
]),
......@@ -120,18 +130,38 @@ def sample(model_path):
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
model=Model(),
input_names=['zc'],
output_names=['gen/gen']))
input_names=['z_cat', 'z_uni_a', 'z_uni_b', 'z_noise'],
output_names=['gen/viz']))
# sample all one-hot encodings (10 times)
z_cat = np.tile(np.eye(10), [10, 1])
# sample continuous variables from -2 to +2, as mentioned in the paper
z_uni = np.linspace(-2.0, 2.0, num=100)
z_uni = z_uni[:, None]
IMG_SIZE = 400
eye = []
for k in np.eye(10):
eye = eye + [k] * 10
inputs = np.asarray(eye)
while True:
o = pred([inputs])
o = (o[0] + 1) * 128.0
viz = next(build_patch_list(o, nr_row=10, nr_col=10))
viz = cv2.resize(viz, (800, 800))
# only the categorical code is turned on (continuous codes are zero)
z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
o = pred([z_cat, z_uni * 0, z_uni * 0, z_noise])[0]
viz1 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))
# show the effect of the first continuous variable with the noise fixed
o = pred([z_cat, z_uni, z_uni * 0, z_noise * 0])[0]
viz2 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))
# show the effect of the second continuous variable with the noise fixed
o = pred([z_cat, z_uni * 0, z_uni, z_noise * 0])[0]
viz3 = next(build_patch_list(o, nr_row=10, nr_col=10))
viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))
viz = next(build_patch_list(
[viz1, viz2, viz3],
nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0)))
interactive_imshow(viz)
......@@ -144,6 +174,7 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample:
BATCH = 100
sample(args.load)
else:
config = get_config()
......
......@@ -37,7 +37,7 @@ This is a visualization from tensorboard. Left to right: original, ground truth,
## InfoGAN-mnist.py
Reproduce one MNIST experiment in InfoGAN.
By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
the network learns to map the 10 variables to 10 digits in a completely unsupervised way.
By assuming 10 latent variables corresponding to a categorical distribution and 2 latent variables corresponding to a uniform distribution, and by maximizing mutual information,
the network learns to map the 10 categorical variables to the 10 digits, and the other two latent variables to rotation and thickness, in a completely unsupervised way.
![infogan](demo/InfoGAN-mnist.jpg)
......@@ -19,7 +19,9 @@ __all__ = ['get_default_sess_config',
'restore_collection',
'clear_collection',
'freeze_collection',
'get_tf_version']
'get_tf_version',
'get_name_scope_name'
]
def get_default_sess_config(mem_fraction=0.99):
......@@ -165,3 +167,15 @@ def get_tf_version():
int:
"""
return int(tf.__version__.split('.')[1])
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
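
The helper infers the current name scope by asking the graph for a unique name built from a probe string and stripping that suffix. A hedged usage sketch, assuming TF 1.x graph mode and that the function lives in `tensorpack.tfutils.common` alongside the other symbols in this `__all__` list:

```python
import tensorflow as tf
from tensorpack.tfutils.common import get_name_scope_name  # assumed module path

with tf.name_scope('GAN_loss'):
    with tf.name_scope('discrim'):
        print(get_name_scope_name())   # expected: GAN_loss/discrim
print(get_name_scope_name())           # expected: '' (top-level scope)
```
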
......@@ -18,7 +18,7 @@ except ImportError:
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb']
'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb', 'stack_images']
def pyplot2img(plt):
......@@ -62,7 +62,7 @@ def minnone(x, y):
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
Args:
img (np.ndarray): an image to show.
img (np.ndarray): an image (expected to be BGR) to show.
lclick_cb: a callback func(img, x, y) for left click event.
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
specify a callback func(img) for keypress.
......@@ -112,7 +112,8 @@ def build_patch_list(patch_list,
max_width(int), max_height(int): Maximum allowed size of the
visualization image. If ``nr_row/nr_col`` are not given, will use this to infer the rows and cols.
shuffle(bool): shuffle the images inside ``patch_list``.
bgcolor(int): background color in [0, 255].
bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
or a BGR tuple.
viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
lclick_cb: A callback function to get called when ``viz==True`` and an
image get clicked. It takes the image patch and its index in
......@@ -139,13 +140,21 @@ def build_patch_list(patch_list,
if nr_col is None:
nr_col = minnone(nr_col, max_width / (pw + border))
if isinstance(bgcolor, int):
bgchannel = 1
else:
bgchannel = 3
canvas_channel = max(patch_list.shape[3], bgchannel)
canvas = np.zeros((nr_row * (ph + border) - border,
nr_col * (pw + border) - border,
patch_list.shape[3]), dtype='uint8')
canvas_channel), dtype='uint8')
def draw_patch(plist):
cur_row, cur_col = 0, 0
canvas.fill(bgcolor)
if bgchannel == 1:
canvas.fill(bgcolor)
else:
canvas[:, :, :] = bgcolor
for patch in plist:
r0 = cur_row * (ph + border)
c0 = cur_col * (pw + border)
......@@ -267,6 +276,42 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
return intensity.astype('float32') * 255.0
def stack_images(imgs, vertical=False):
"""Stack images with different shapes and different number of channels.
Args:
imgs (np.array): imgage
vertical (bool, optional): stack images vertically
Returns:
np.array: stacked images
"""
rows = [x.shape[0] for x in imgs]
cols = [x.shape[1] for x in imgs]
if vertical:
if len(imgs[0].shape) == 2:
out = np.zeros((np.sum(rows), max(cols)), dtype='uint8')
else:
out = np.zeros((np.sum(rows), max(cols), 3), dtype='uint8')
else:
if len(imgs[0].shape) == 2:
out = np.zeros((max(rows), np.sum(cols)), dtype='uint8')
else:
out = np.zeros((max(rows), np.sum(cols), 3), dtype='uint8')
offset = 0
for i, img in enumerate(imgs):
assert img.max() > 1, "expect images within range [0, 255]"
if vertical:
out[offset:offset + rows[i], :cols[i]] = img
offset += rows[i]
else:
out[:rows[i], offset:offset + cols[i]] = img
offset += cols[i]
return out
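
A hedged usage sketch of the new helper, assuming it is imported from `tensorpack.utils.viz` like the other functions in this file:

```python
import numpy as np
from tensorpack.utils.viz import stack_images  # assumed import path

a = np.full((50, 80, 3), 200, dtype='uint8')   # 50x80 BGR patch
b = np.full((60, 40, 3), 128, dtype='uint8')   # 60x40 BGR patch
side_by_side = stack_images([a, b])                  # shape (60, 120, 3)
top_to_bottom = stack_images([a, b], vertical=True)  # shape (110, 80, 3)
```
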
if __name__ == '__main__':
imglist = []
for i in range(100):
......