Commit 9d6197aa authored by Patrick Wieschollek, committed by Yuxin Wu

GAN refactor (#107)

* initial public refactor

* add uniform

* simplify code

* split uniform-factor

* fix prior

* working version

* more documentation

* clean up done for InfoGAN

* updated other examples to match changes in GAN.py

* final edits from my side

* docs changes in base class

* more docs and interface name

* some changes in name scope and docs

* use property for _dim

* visualization & gradient scale according to the paper

* update the demo img
parent 677af274
@@ -17,7 +17,7 @@ from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf
-from GAN import GANTrainer, RandomZData, build_GAN_losses
+from GAN import GANTrainer, RandomZData, GANModelDesc
"""
DCGAN on CelebA dataset.
@@ -35,7 +35,7 @@ CFG.BATCH = 128
CFG.Z_DIM = 100

-class Model(ModelDesc):
+class Model(GANModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input')]
@@ -87,10 +87,8 @@ class Model(ModelDesc):
        with tf.variable_scope('discrim', reuse=True):
            vecneg = self.discriminator(image_gen)
-        self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
-        all_vars = tf.trainable_variables()
-        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
-        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
+        self.build_losses(vecpos, vecneg)
+        self.collect_variables()

def get_data():
...
@@ -7,11 +7,68 @@ import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, TowerContext,
-                        get_global_step_var, QueueInput)
+                        get_global_step_var, QueueInput, ModelDesc)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
+from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient
from tensorpack.dataflow import DataFlow

class GANModelDesc(ModelDesc):
    def collect_variables(self):
        """Extract variables by prefix
        """
        all_vars = tf.trainable_variables()
        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]

    def build_losses(self, logits_real, logits_fake):
        """D and G play a two-player minimax game with value function V(G, D):

        min_G max_D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]

        Note, we swap the 0, 1 labels as suggested in "Improving GANs".

        Args:
            logits_real (tf.Tensor): discrim logits from real samples
            logits_fake (tf.Tensor): discrim logits from fake samples produced by the generator

        This method does not return the losses; it sets ``self.g_loss`` and ``self.d_loss``
        (and the corresponding accuracies) and adds them to the moving summaries.
        """
        with tf.name_scope("GAN_loss"):
            score_real = tf.sigmoid(logits_real)
            score_fake = tf.sigmoid(logits_fake)
            tf.summary.histogram('score-real', score_real)
            tf.summary.histogram('score-fake', score_fake)

            with tf.name_scope("discrim"):
                d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_real, labels=tf.zeros_like(logits_real)), name='loss_real')
                d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss_fake')
                d_pos_acc = tf.reduce_mean(tf.cast(score_real < 0.5, tf.float32), name='accuracy_real')
                d_neg_acc = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy_fake')
                self.d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')
                self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss')

            with tf.name_scope("gen"):
                self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss')
                self.g_accuracy = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy')

            add_moving_summary(self.g_loss, self.d_loss, self.d_accuracy, self.g_accuracy)

    def get_gradient_processor_g(self):
        return [CheckGradient()]

    def get_gradient_processor_d(self):
        return [CheckGradient()]
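
Editorial note: the snippet below is a minimal NumPy sketch, not part of the commit, of the objective that build_losses implements. It uses the same numerically stable formula as tf.nn.sigmoid_cross_entropy_with_logits and the swapped 0/1 labels mentioned in the docstring (real -> label 0, fake -> label 1 for D; the generator wants fakes labelled 0).

import numpy as np

def sigmoid_xent(logits, labels):
    # stable sigmoid cross-entropy, same formula TF uses:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

logits_real = np.array([2.0, 1.5, -0.3])   # made-up D logits on real samples
logits_fake = np.array([-1.0, 0.2, -2.5])  # made-up D logits on generated samples

d_loss = .5 * sigmoid_xent(logits_real, 0.).mean() + .5 * sigmoid_xent(logits_fake, 1.).mean()
g_loss = sigmoid_xent(logits_fake, 0.).mean()   # non-saturating generator loss
print(d_loss, g_loss)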
class GANTrainer(FeedfreeTrainerBase):
    def __init__(self, config):
        self._input_method = QueueInput(config.dataflow)
@@ -22,11 +79,18 @@ class GANTrainer(FeedfreeTrainerBase):
        with TowerContext(''):
            actual_inputs = self._get_input_tensors()
            self.model.build_graph(actual_inputs)
-        self.g_min = self.config.optimizer.minimize(self.model.g_loss,
-                                                    var_list=self.model.g_vars, name='g_op')
+        grads = self.config.optimizer.compute_gradients(
+            self.model.g_loss, var_list=self.model.g_vars)
+        grads = apply_grad_processors(
+            grads, self.model.get_gradient_processor_g())
+        self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
        with tf.control_dependencies([self.g_min]):
-            self.d_min = self.config.optimizer.minimize(self.model.d_loss,
-                                                        var_list=self.model.d_vars, name='d_op')
+            grads = self.config.optimizer.compute_gradients(
+                self.model.d_loss, var_list=self.model.d_vars)
+            grads = apply_grad_processors(
+                grads, self.model.get_gradient_processor_d())
+            self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
        self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
        self.summary_op = summary_moving_average()
        self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
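
Editorial sketch, not part of the diff: splitting minimize into compute_gradients/apply_gradients is what allows gradient processors to run in between, and the control dependency forces the discriminator step to run only after the generator step. A stripped-down illustration with toy variables, assuming the TF 1.x graph API:

import tensorflow as tf

g_w = tf.Variable(1.0, name='gen/w')        # stand-in for generator weights
d_w = tf.Variable(1.0, name='discrim/w')    # stand-in for discriminator weights
g_loss = tf.square(g_w - 3.0)
d_loss = tf.square(d_w + 2.0)

opt = tf.train.AdamOptimizer(2e-4, beta1=0.5)
g_grads = opt.compute_gradients(g_loss, var_list=[g_w])
# ... a gradient processor could inspect or rewrite g_grads here ...
g_min = opt.apply_gradients(g_grads, name='g_op')
with tf.control_dependencies([g_min]):      # D update only runs after the G update
    d_grads = opt.compute_gradients(d_loss, var_list=[d_w])
    d_min = opt.apply_gradients(d_grads, name='d_op')
train_op = tf.group(d_min, name='train_op')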
@@ -44,31 +108,3 @@ class RandomZData(DataFlow):
    def get_data(self):
        while True:
            yield [np.random.uniform(-1, 1, size=self.shape)]

-def build_GAN_losses(vecpos, vecneg):
-    """
-    :param vecpos, vecneg: output of the discriminator (logits) for real
-        and fake images.
-    :return: (loss of G, loss of D)
-    """
-    sigmpos = tf.sigmoid(vecpos)
-    sigmneg = tf.sigmoid(vecneg)
-    tf.summary.histogram('sigmoid-pos', sigmpos)
-    tf.summary.histogram('sigmoid-neg', sigmneg)
-    d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=vecpos, labels=tf.ones_like(vecpos)), name='d_CE_loss_pos')
-    d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=vecneg, labels=tf.zeros_like(vecneg)), name='d_CE_loss_neg')
-    d_pos_acc = tf.reduce_mean(tf.cast(sigmpos > 0.5, tf.float32), name='pos_acc')
-    d_neg_acc = tf.reduce_mean(tf.cast(sigmneg < 0.5, tf.float32), name='neg_acc')
-    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=vecneg, labels=tf.ones_like(vecneg)), name='g_CE_loss')
-    d_loss = tf.add(d_loss_pos, d_loss_neg, name='d_CE_loss')
-    add_moving_summary(d_loss_pos, d_loss_neg,
-                       g_loss, d_loss,
-                       d_pos_acc, d_neg_acc)
-    return g_loss, d_loss
@@ -16,7 +16,7 @@ from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average
import tensorpack.tfutils.symbolic_functions as symbf
-from GAN import GANTrainer, build_GAN_losses
+from GAN import GANTrainer, GANModelDesc
"""
To train:
@@ -42,7 +42,7 @@ LAMBDA = 100
NF = 64  # number of filter

-class Model(ModelDesc):
+class Model(GANModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
@@ -114,7 +114,7 @@ class Model(ModelDesc):
        with tf.variable_scope('discrim', reuse=True):
            fake_pred = self.discriminator(input, fake_output)
-        self.g_loss, self.d_loss = build_GAN_losses(real_pred, fake_pred)
+        self.build_losses(real_pred, fake_pred)
        errL1 = tf.reduce_mean(tf.abs(fake_output - output), name='L1_loss')
        self.g_loss = tf.add(self.g_loss, LAMBDA * errL1, name='total_g_loss')
        add_moving_summary(errL1, self.g_loss)
@@ -129,9 +129,7 @@ class Model(ModelDesc):
            viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
            tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
-        all_vars = tf.trainable_variables()
-        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
-        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
+        self.collect_variables()

def split_input(img):
...
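
Editorial note: the two lines adding LAMBDA * errL1 above combine the adversarial loss from build_losses with an L1 reconstruction term, i.e. a pix2pix-style generator objective (with lambda = 100 as set by LAMBDA in this example):

\mathcal{L}_G = \mathcal{L}_{\mathrm{GAN}}(G, D) + \lambda \, \mathbb{E}_{x, y}\big[ \lVert y - G(x) \rVert_1 \big], \qquad \lambda = 100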
@@ -12,13 +12,16 @@ import argparse
from tensorpack import *
from tensorpack.utils.viz import *
+from tensorpack.tfutils.distributions import *
import tensorpack.tfutils.symbolic_functions as symbf
-from GAN import GANTrainer, build_GAN_losses
+from tensorpack.tfutils.gradproc import ScaleGradient, CheckGradient
+from GAN import GANTrainer, GANModelDesc

BATCH = 128
+NOISE_DIM = 62

-class Model(ModelDesc):
+class Model(GANModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, 28, 28), 'input')]
@@ -29,13 +32,12 @@ class Model(ModelDesc):
        l = tf.reshape(l, [-1, 7, 7, 128])
        l = Deconv2D('deconv1', l, [14, 14, 64], 4, 2, nl=BNReLU)
        l = Deconv2D('deconv2', l, [28, 28, 1], 4, 2, nl=tf.identity)
-        l = tf.nn.tanh(l, name='gen')
+        l = tf.tanh(l, name='gen')
        return l

    def discriminator(self, imgs):
-        """ return a (b, 1) logits"""
        with argscope(Conv2D, nl=tf.identity, kernel_shape=4, stride=2), \
-                argscope(LeakyReLU, alpha=0.2):
+                argscope(LeakyReLU, alpha=0.1):
            l = (LinearWrap(imgs)
                 .Conv2D('conv0', 64)
                 .LeakyReLU()
@@ -48,49 +50,57 @@ class Model(ModelDesc):
            encoder = (LinearWrap(l)
                       .FullyConnected('fce1', 128, nl=tf.identity)
                       .BatchNorm('bne').LeakyReLU()
-                       .FullyConnected('fce-out', 10, nl=tf.identity)())
+                       .FullyConnected('fce-out', self.factors.param_dim, nl=tf.identity)())
        return logits, encoder

    def _build_graph(self, input_vars):
-        image_pos = input_vars[0]
-        image_pos = tf.expand_dims(image_pos * 2.0 - 1, -1)
-        prior_prob = tf.constant([0.1] * 10, name='prior_prob')
-        # assume first 10 is categorical
-        ids = tf.multinomial(tf.zeros([BATCH, 10]), num_samples=1)[:, 0]
-        zc = tf.one_hot(ids, 10, name='zc_train')
-        zc = tf.placeholder_with_default(zc, [None, 10], name='zc')
-        z = tf.random_uniform(tf.stack([tf.shape(zc)[0], 90]), -1, 1, name='z_train')
-        z = tf.placeholder_with_default(z, [None, 90], name='z')
-        z = tf.concat_v2([zc, z], 1, name='fullz')
+        real_sample = input_vars[0]
+        real_sample = tf.expand_dims(real_sample * 2.0 - 1, -1)
+        # latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
+        self.factors = ProductDistribution("factors", [CategoricalDistribution("cat", 10),
+                                                       GaussianDistributionUniformPrior("uni_a", 1),
+                                                       GaussianDistributionUniformPrior("uni_b", 1),
+                                                       NoiseDistribution("noise", NOISE_DIM)])
+        z = self.factors.sample_prior(BATCH, name='zc')
        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02)):
            with tf.variable_scope('gen'):
-                image_gen = self.generator(z)
-                tf.summary.image('gen', image_gen, max_outputs=30)
+                fake_sample = self.generator(z)
+                fake_sample_viz = tf.cast((fake_sample + 1) * 128.0, tf.uint8, name='viz')
+            tf.summary.image('gen', fake_sample_viz, max_outputs=30)

+            # TODO investigate how bn stats should be updated across two discrim
            with tf.variable_scope('discrim'):
-                vecpos, _ = self.discriminator(image_pos)
+                real_pred, _ = self.discriminator(real_sample)
            with tf.variable_scope('discrim', reuse=True):
-                vecneg, dist_param = self.discriminator(image_gen)
-                logprob = tf.nn.log_softmax(dist_param)  # log prob of each category
+                fake_pred, dist_param = self.discriminator(fake_sample)

+        # post-process all dist_params from discriminator
+        encoder_activation = self.factors.encoder_activation(dist_param)
+        with tf.name_scope("mutual_information"):
+            MIs = self.factors.mutual_information(z, encoder_activation)
+            mi = tf.add_n(MIs, name="total")
+            summary.add_moving_summary(MIs + [mi])

+        # default GAN objective
+        self.build_losses(real_pred, fake_pred)

-        # Q(c|x) = Q(zc | image_gen)
-        log_qc_given_x = tf.reduce_sum(logprob * zc, 1, name='logQc_x')  # bx1
-        log_qc = tf.reduce_sum(prior_prob * zc, 1, name='logQc')
-        Elog_qc_given_x = tf.reduce_mean(log_qc_given_x, name='ElogQc_x')
-        Hc = tf.reduce_mean(-log_qc, name='Hc')
-        MIloss = tf.multiply(Hc + Elog_qc_given_x, -1.0, name='neg_MI')
-        self.g_loss, self.d_loss = build_GAN_losses(vecpos, vecneg)
-        self.g_loss = tf.add(self.g_loss, MIloss, name='total_g_loss')
-        self.d_loss = tf.add(self.d_loss, MIloss, name='total_d_loss')
-        summary.add_moving_summary(MIloss, self.g_loss, self.d_loss, Hc, Elog_qc_given_x)
+        # subtract mutual information for latent factors (we want to maximize them)
+        self.g_loss = tf.subtract(self.g_loss, mi, name='total_g_loss')
+        self.d_loss = tf.subtract(self.d_loss, mi, name='total_d_loss')
+        summary.add_moving_summary(self.g_loss, self.d_loss)

-        all_vars = tf.trainable_variables()
-        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
-        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
+        # distinguish between variables of generator and discriminator updates
+        self.collect_variables()

+    def get_gradient_processor_g(self):
+        return [CheckGradient(), ScaleGradient(('.*', 5), log=False)]
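
Editorial sketch, not in the commit: the latent code fed to the generator is the concatenation produced by ProductDistribution.sample_prior, i.e. a 10-way one-hot category, two scalars with a U(-1, 1) prior, and 62 dimensions of unstructured uniform noise, giving a 74-dimensional z; the encoder head fce-out only predicts param_dim = 10 + 1 + 1 + 0 = 12 values. A NumPy stand-in for one prior sample:

import numpy as np

BATCH, NOISE_DIM = 128, 62
ids = np.random.randint(0, 10, size=BATCH)
z_cat = np.eye(10)[ids]                                 # one-hot category, shape (128, 10)
z_uni_a = np.random.uniform(-1, 1, (BATCH, 1))          # U(-1, 1) prior
z_uni_b = np.random.uniform(-1, 1, (BATCH, 1))
z_noise = np.random.uniform(-1, 1, (BATCH, NOISE_DIM))  # unstructured noise
z = np.concatenate([z_cat, z_uni_a, z_uni_b, z_noise], axis=1)
print(z.shape)   # (128, 74)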
def get_data():
@@ -105,7 +115,7 @@ def get_config():
    lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
    return TrainConfig(
        dataflow=dataset,
-        optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
+        optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6),
        callbacks=Callbacks([
            StatPrinter(), ModelSaver(),
        ]),
@@ -120,18 +130,38 @@ def sample(model_path):
    pred = OfflinePredictor(PredictConfig(
        session_init=get_model_loader(model_path),
        model=Model(),
-        input_names=['zc'],
-        output_names=['gen/gen']))
+        input_names=['z_cat', 'z_uni_a', 'z_uni_b', 'z_noise'],
+        output_names=['gen/viz']))

+    # sample all one-hot encodings (10 times)
+    z_cat = np.tile(np.eye(10), [10, 1])
+    # sample continuous variables from -2 to +2 as mentioned in the paper
+    z_uni = np.linspace(-2.0, 2.0, num=100)
+    z_uni = z_uni[:, None]
+    IMG_SIZE = 400

-    eye = []
-    for k in np.eye(10):
-        eye = eye + [k] * 10
-    inputs = np.asarray(eye)

    while True:
-        o = pred([inputs])
-        o = (o[0] + 1) * 128.0
-        viz = next(build_patch_list(o, nr_row=10, nr_col=10))
-        viz = cv2.resize(viz, (800, 800))
+        # only categorical turned on
+        z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
+        o = pred([z_cat, z_uni * 0, z_uni * 0, z_noise])[0]
+        viz1 = next(build_patch_list(o, nr_row=10, nr_col=10))
+        viz1 = cv2.resize(viz1, (IMG_SIZE, IMG_SIZE))

+        # show effect of first continuous variable with fixed noise
+        o = pred([z_cat, z_uni, z_uni * 0, z_noise * 0])[0]
+        viz2 = next(build_patch_list(o, nr_row=10, nr_col=10))
+        viz2 = cv2.resize(viz2, (IMG_SIZE, IMG_SIZE))

+        # show effect of second continuous variable with fixed noise
+        o = pred([z_cat, z_uni * 0, z_uni, z_noise * 0])[0]
+        viz3 = next(build_patch_list(o, nr_row=10, nr_col=10))
+        viz3 = cv2.resize(viz3, (IMG_SIZE, IMG_SIZE))

+        viz = next(build_patch_list(
+            [viz1, viz2, viz3],
+            nr_row=1, nr_col=3, border=5, bgcolor=(255, 0, 0)))
        interactive_imshow(viz)
@@ -144,6 +174,7 @@ if __name__ == '__main__':
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.sample:
+        BATCH = 100
        sample(args.load)
    else:
        config = get_config()
...
@@ -37,7 +37,7 @@ This is a visualization from tensorboard. Left to right: original, ground truth,

## InfoGAN-mnist.py

Reproduce one mnist experiment in InfoGAN.

-By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
-the network learns to map the 10 variables to 10 digits in a completely unsupervised way.
+By assuming 10 latent variables corresponding to a categorical distribution, 2 latent variables corresponding to a uniform distribution, and maximizing mutual information,
+the network learns to map the 10 variables to 10 digits and the other two latent variables to rotation and thickness in a completely unsupervised way.

![infogan](demo/InfoGAN-mnist.jpg)
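
Editorial note: "maximizing mutual information" refers to the InfoGAN objective, which augments the usual GAN value function with a variational lower bound on the mutual information between the latent code c and the generated sample (the bound that Distribution.mutual_information in this commit computes):

\min_G \max_D \; V(D, G) - \lambda \, I\big(c;\, G(z, c)\big), \qquad I\big(c;\, G(z, c)\big) \;\ge\; H(c) + \mathbb{E}\big[\log Q(c \mid G(z, c))\big]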
@@ -19,7 +19,9 @@ __all__ = ['get_default_sess_config',
           'restore_collection',
           'clear_collection',
           'freeze_collection',
-           'get_tf_version']
+           'get_tf_version',
+           'get_name_scope_name'
+           ]

def get_default_sess_config(mem_fraction=0.99):
@@ -165,3 +167,15 @@ def get_tf_version():
        int:
    """
    return int(tf.__version__.split('.')[1])

def get_name_scope_name():
    """
    Returns:
        str: the name of the current name scope, without the ending '/'.
    """
    g = tf.get_default_graph()
    s = "RANDOM_STR_ABCDEFG"
    unique = g.unique_name(s)
    scope = unique[:-len(s)].rstrip('/')
    return scope
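
A brief usage sketch (editorial): g.unique_name(s) returns the name a new op called s would get in the current scope, so stripping that suffix recovers the scope prefix. This assumes the helper is re-exported from tensorpack.tfutils, as the relative import in distributions.py below suggests:

import tensorflow as tf
from tensorpack.tfutils import get_name_scope_name

with tf.name_scope('outer'):
    with tf.name_scope('inner'):
        print(get_name_scope_name())   # expected to print 'outer/inner'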
import tensorflow as tf
from functools import wraps
import numpy as np
from ..utils import logger
from ..tfutils import get_name_scope_name
__all__ = ['Distribution',
'CategoricalDistribution', 'GaussianDistributionUniformPrior',
'NoiseDistribution', 'ProductDistribution']
# TODO encoder_activation and the ProductDistribution class bring many redundant concat and split
def class_scope(func):
"""
A decorator which wraps a function with a name_scope: "{class_name}_{method_name}".
The "{class_name}" is either ``cls.name`` or simply the class name.
It helps enhance TensorBoard graph visualization by grouping operators.
This is just syntactic sugar to avoid writing ``with tf.name_scope(...)`` in each method.
"""
@wraps(func)
def _impl(self, *args, **kwargs):
# is there a specific name?
distr_name = self.name
if distr_name is None:
distr_name = self.__class__.__name__
# scope it only when it is not already scoped with current class
if distr_name not in get_name_scope_name():
with tf.name_scope(distr_name + "_" + func.__name__):
return func(self, *args, **kwargs)
else:
return func(self, *args, **kwargs)
return _impl
class Distribution(object):
"""
Base class of symbolic distribution utilities
(the distribution parameters can be symbolic tensors).
"""
name = None
def __init__(self, name):
"""
Args:
name(str): the name to be used for scope and tensors in this
distribution.
"""
self.name = name
@class_scope
def loglikelihood(self, x, theta):
"""
Args:
x: samples of shape (batch, sample_dim)
theta: model parameters of shape (batch, param_dim)
Returns:
log likelihood of each sample, of shape (batch,)
"""
assert x.get_shape().ndims == 2 and \
x.get_shape()[1] == self.sample_dim, \
x.get_shape()
assert theta.get_shape().ndims == 2 and \
theta.get_shape()[1] == self.param_dim, \
theta.get_shape()
ret = self._loglikelihood(x, theta)
assert ret.get_shape().ndims == 1, ret.get_shape()
return ret
@class_scope
def loglikelihood_prior(self, x):
"""Log-likelihood of samples under the prior of this distribution.
Args:
x: samples of shape (batch, sample_dim)
Returns:
a symbolic vector containing loglikelihood of each sample,
using prior of this distribution.
"""
assert x.get_shape().ndims == 2 and \
x.get_shape()[1] == self.sample_dim, \
x.get_shape()
batch_size = x.get_shape().as_list()[0]
s = self.prior(batch_size)
return self._loglikelihood(x, s)
@class_scope
def mutual_information(self, x, theta):
"""
Approximates mutual information between x and some information s.
Here we return a variational lower bound of the mutual information,
assuming a proposal distribution Q(x|s) (which approximates P(x|s) )
has the form of this distribution parameterized by theta.
.. math::
I(x;s) = H(x) - H(x|s)
= H(x) + E[\log P(x|s)]
\\ge H(x) + E_{x \sim P(x|s)}[\log Q(x|s)]
Args:
x: samples of shape (batch, sample_dim)
theta: parameters defining the proposal distribution Q. shape (batch, param_dim).
Returns:
lower-bounded mutual information, a scalar tensor.
"""
entr = self.prior_entropy(x)
cross_entr = self.entropy(x, theta)
return tf.subtract(entr, cross_entr, name="mutual_information")
@class_scope
def prior_entropy(self, x):
r"""
Estimated entropy of the prior distribution,
from a batch of samples (as average). It
estimates the likelihood of samples using the prior distribution.
.. math::
H(x) = -E[\log p(x_i)], \text{where } p \text{ is the prior}
Args:
x: samples of shape (batch, sample_dim)
Returns:
a scalar, estimated entropy.
"""
return tf.reduce_mean(-self.loglikelihood_prior(x), name="prior_entropy")
@class_scope
def entropy(self, x, theta):
r""" Entropy of this distribution parameterized by theta,
estimated from a batch of samples.
.. math::
H(x) = - E[\log p(x_i)], \text{where } p \text{ is parameterized by } \theta.
Args:
x: samples of shape (batch, sample_dim)
theta: model parameters of shape (batch, param_dim)
Returns:
a scalar tensor, the entropy.
"""
return tf.reduce_mean(-self.loglikelihood(x, theta), name="entropy")
@class_scope
def prior(self, batch_size):
"""Get the prior parameters of this distribution.
Returns:
a (batch, param_dim) 2D tensor, containing priors of
this distribution repeated for batch_size times.
"""
return self._prior(batch_size)
@class_scope
def encoder_activation(self, dist_param):
""" An activation function to produce
feasible distribution parameters from unconstrained raw network output.
Args:
dist_param: output from a network, of shape (batch, param_dim).
Returns:
a tensor of the same shape, the distribution parameters.
"""
return self._encoder_activation(dist_param)
def sample_prior(self, batch_size):
"""
Sample a batch of data with the prior distribution.
Args:
batch_size(int):
Returns:
samples of shape (batch, sample_dim)
"""
s = self._sample_prior(batch_size)
return s
@property
def param_dim(self):
"""
Returns:
int: the dimension of parameters of this distribution.
"""
raise NotImplementedError
@property
def sample_dim(self):
"""
Returns:
int: the dimension of samples out of this distribution.
"""
raise NotImplementedError
def _loglikelihood(self, x, theta):
raise NotImplementedError
def _prior(self, batch_size):
raise NotImplementedError
def _sample_prior(self, batch_size):
raise NotImplementedError
def _encoder_activation(self, dist_param):
return dist_param
class CategoricalDistribution(Distribution):
""" Categorical distribution of a set of classes.
Each sample is a one-hot vector.
"""
def __init__(self, name, cardinality):
"""
Args:
cardinality (int): number of categories
"""
super(CategoricalDistribution, self).__init__(name)
self.cardinality = cardinality
def _loglikelihood(self, x, theta):
eps = 1e-8
return tf.reduce_sum(tf.log(theta + eps) * x, reduction_indices=1)
def _prior(self, batch_size):
return tf.constant(1.0 / self.cardinality,
tf.float32, [batch_size, self.cardinality])
def _sample_prior(self, batch_size):
ids = tf.multinomial(tf.zeros([batch_size, self.cardinality]), num_samples=1)[:, 0]
ret = tf.one_hot(ids, self.cardinality)
return ret
def _encoder_activation(self, dist_param):
return tf.nn.softmax(dist_param)
@property
def param_dim(self):
return self.cardinality
@property
def sample_dim(self):
return self.cardinality
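
Editorial sanity check, not part of the commit: for CategoricalDistribution the prior log-likelihood of any one-hot sample is log(1/k), so the prior entropy is log k, and the mutual-information lower bound is that entropy minus the cross-entropy under the encoder's softmax output. In NumPy, mirroring _loglikelihood:

import numpy as np

k = 10
x = np.eye(k)[[3, 7]]                     # a batch of 2 one-hot samples
theta = np.full((2, k), 1.0 / k)          # encoder output, here no better than the prior
eps = 1e-8

loglik = np.sum(np.log(theta + eps) * x, axis=1)                  # mirrors _loglikelihood
prior_entropy = np.mean(-np.sum(np.log(1.0 / k) * x, axis=1))     # = log(10) ~ 2.303
entropy = np.mean(-loglik)
mi_lower_bound = prior_entropy - entropy                          # ~ 0 for an uninformative encoder
print(prior_entropy, mi_lower_bound)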
class GaussianDistributionUniformPrior(Distribution):
"""Gaussian distribution with prior U(-1,1).
It implements a Gaussian with uniform :meth:`sample_prior` method.
"""
def __init__(self, name, dim, fixed_std=True):
"""
Args:
dim(int): the dimension of samples.
fixed_std (bool): if True, will use 1 as std for all dimensions.
"""
super(GaussianDistributionUniformPrior, self).__init__(name)
self.dim = dim
self.fixed_std = fixed_std
def _loglikelihood(self, x, theta):
eps = 1e-8
if self.fixed_std:
mean = theta
stddev = tf.ones_like(mean)
exponent = (x - mean)
else:
mean, stddev = tf.split(theta, 2, axis=1)
exponent = (x - mean) / (stddev + eps)
return tf.reduce_sum(
- 0.5 * np.log(2 * np.pi) - tf.log(stddev + eps) - 0.5 * tf.square(exponent),
reduction_indices=1
)
def _prior(self, batch_size):
if self.fixed_std:
return tf.zeros([batch_size, self.param_dim])
else:
return tf.concat_v2([tf.zeros([batch_size, self.param_dim]),
tf.ones([batch_size, self.param_dim])], 1)
def _sample_prior(self, batch_size):
return tf.random_uniform([batch_size, self.dim], -1, 1)
def _encoder_activation(self, dist_param):
if self.fixed_std:
return dist_param
else:
mean, stddev = tf.split(dist_param, 2, axis=1)
# this is from https://github.com/openai/InfoGAN. don't know why
stddev = tf.sqrt(tf.exp(stddev))
return tf.concat_v2([mean, stddev], axis=1)
@property
def param_dim(self):
if self.fixed_std:
return self.dim
else:
return 2 * self.dim
@property
def sample_dim(self):
return self.dim
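
Editorial check of the fixed-std branch of _loglikelihood above: with mean theta and unit std, each dimension contributes -0.5*log(2*pi) - 0.5*(x - mean)^2 (up to the eps term). A small NumPy verification with assumed values:

import numpy as np

x = np.array([[0.5, -1.0]])
mean = np.zeros((1, 2))
eps = 1e-8
ll = np.sum(-0.5 * np.log(2 * np.pi) - np.log(1.0 + eps) - 0.5 * (x - mean) ** 2, axis=1)
print(ll)   # per-dim terms of about -1.0439 and -1.4189, so roughly [-2.463]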
class NoiseDistribution(Distribution):
"""This is not really a distribution.
It is the uniform noise input of a GAN, given the same interface as :class:`Distribution` to simplify the implementation.
"""
def __init__(self, name, dim):
"""
Args:
dim(int): the dimension of the noise.
"""
# TODO more options, e.g. use gaussian or uniform?
super(NoiseDistribution, self).__init__(name)
self.dim = dim
def _loglikelihood(self, x, theta):
return 0
def _prior(self):
return 0
def _sample_prior(self, batch_size):
zc = tf.random_uniform([batch_size, self.dim], -1, 1)
return zc
def _encoder_activation(self, dist_param):
return 0
@property
def param_dim(self):
return 0
@property
def sample_dim(self):
return self.dim
class ProductDistribution(Distribution):
"""A product of a list of independent distributions. """
def __init__(self, name, dists):
"""
Args:
dists(list): list of :class:`Distribution`.
"""
super(ProductDistribution, self).__init__(name)
self.dists = dists
@property
def param_dim(self):
return np.sum([d.param_dim for d in self.dists])
def _splitter(self, s, param):
"""Input is split into a list of chunks according
to dist.param_dim along axis=1
Args:
s (tf.Tensor): batch of vectors with shape (batch, param_dim or sample_dim)
param (bool): split params, otherwise split samples
Yields:
tf.Tensor: chunk from input of length N_i with sum N_i = N
"""
offset = 0
for dist in self.dists:
if param:
off = dist.param_dim
else:
off = dist.sample_dim
yield s[:, offset:offset + off]
offset += off
def mutual_information(self, x, theta):
"""
Return mutual information of all distributions but skip the
unparameterized ones.
Note:
It returns a list, as one might use different weights for each
distribution.
Returns:
list[tf.Tensor]: mutual informations of each distribution.
"""
MIs = [] # noqa
for dist, xi, ti in zip(self.dists,
self._splitter(x, False),
self._splitter(theta, True)):
if dist.param_dim > 0:
MIs.append(dist.mutual_information(xi, ti))
return MIs
def sample_prior(self, batch_size, name='sample_prior'):
"""
Concat the samples from all distributions.
Returns:
tf.Tensor: a tensor of shape (batch, sample_dim), but first dimension is statically unknown,
allowing you to do inference with custom batch size.
"""
samples = []
for k, dist in enumerate(self.dists):
init = dist._sample_prior(batch_size)
plh = tf.placeholder_with_default(init, [batch_size, dist.sample_dim], name='z_' + dist.name)
samples.append(plh)
logger.info("Placeholder for %s(%s) is %s " % (dist.name, dist.__class__.__name__, plh.name[:-2]))
return tf.concat_v2(samples, 1, name=name)
def _encoder_activation(self, dist_params):
rsl = []
for dist, dist_param in zip(self.dists, self._splitter(dist_params, True)):
if dist.param_dim > 0:
rsl.append(dist._encoder_activation(dist_param))
return tf.concat_v2(rsl, 1)
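
Editorial sketch of the bookkeeping that _splitter performs, using the latent layout from the InfoGAN example above (cat(10) + uni(1) + uni(1) + noise(62)): samples are 74-dimensional while the concatenated encoder parameters have only 10 + 1 + 1 + 0 = 12 dimensions.

import numpy as np

# (name, sample_dim, param_dim) per factor, as defined by the classes above
factors = [('cat', 10, 10), ('uni_a', 1, 1), ('uni_b', 1, 1), ('noise', 62, 0)]

def split(arr, dims):
    # mirrors ProductDistribution._splitter: slice chunks along axis=1
    off = 0
    for d in dims:
        yield arr[:, off:off + d]
        off += d

z = np.random.uniform(-1, 1, (4, sum(s for _, s, _ in factors)))   # samples, shape (4, 74)
theta = np.random.randn(4, sum(p for _, _, p in factors))          # encoder output, shape (4, 12)

for (name, _, _), zi, ti in zip(factors,
                                split(z, [s for _, s, _ in factors]),
                                split(theta, [p for _, _, p in factors])):
    print(name, zi.shape, ti.shape)   # e.g. cat (4, 10) (4, 10) ... noise (4, 62) (4, 0)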
@@ -18,7 +18,7 @@ except ImportError:
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
-           'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb']
+           'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb', 'stack_images']

def pyplot2img(plt):
@@ -62,7 +62,7 @@ def minnone(x, y):
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
    """
    Args:
-        img (np.ndarray): an image to show.
+        img (np.ndarray): an image (expect BGR) to show.
        lclick_cb: a callback func(img, x, y) for left click event.
        kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
            specify a callback func(img) for keypress.
@@ -112,7 +112,8 @@ def build_patch_list(patch_list,
        max_width(int), max_height(int): Maximum allowed size of the
            visualization image. If ``nr_row/nr_col`` are not given, will use this to infer the rows and cols.
        shuffle(bool): shuffle the images inside ``patch_list``.
-        bgcolor(int): background color in [0, 255].
+        bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
+            or a BGR tuple.
        viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
        lclick_cb: A callback function to get called when ``viz==True`` and an
            image get clicked. It takes the image patch and its index in
@@ -139,13 +140,21 @@ def build_patch_list(patch_list,
    if nr_col is None:
        nr_col = minnone(nr_col, max_width / (pw + border))
+    if isinstance(bgcolor, int):
+        bgchannel = 1
+    else:
+        bgchannel = 3
+    canvas_channel = max(patch_list.shape[3], bgchannel)
    canvas = np.zeros((nr_row * (ph + border) - border,
                       nr_col * (pw + border) - border,
-                       patch_list.shape[3]), dtype='uint8')
+                       canvas_channel), dtype='uint8')

    def draw_patch(plist):
        cur_row, cur_col = 0, 0
-        canvas.fill(bgcolor)
+        if bgchannel == 1:
+            canvas.fill(bgcolor)
+        else:
+            canvas[:, :, :] = bgcolor
        for patch in plist:
            r0 = cur_row * (ph + border)
            c0 = cur_col * (pw + border)
@@ -267,6 +276,42 @@ def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
    return intensity.astype('float32') * 255.0

def stack_images(imgs, vertical=False):
    """Stack images with different shapes and different numbers of channels.

    Args:
        imgs (list[np.ndarray]): the images to stack
        vertical (bool, optional): stack the images vertically

    Returns:
        np.ndarray: the stacked image
    """
    rows = [x.shape[0] for x in imgs]
    cols = [x.shape[1] for x in imgs]

    if vertical:
        if len(imgs[0].shape) == 2:
            out = np.zeros((np.sum(rows), max(cols)), dtype='uint8')
        else:
            out = np.zeros((np.sum(rows), max(cols), 3), dtype='uint8')
    else:
        if len(imgs[0].shape) == 2:
            out = np.zeros((max(rows), np.sum(cols)), dtype='uint8')
        else:
            out = np.zeros((max(rows), np.sum(cols), 3), dtype='uint8')

    offset = 0
    for i, img in enumerate(imgs):
        assert img.max() > 1, "expect images within range [0, 255]"
        if vertical:
            out[offset:offset + rows[i], :cols[i]] = img
            offset += rows[i]
        else:
            out[:rows[i], offset:offset + cols[i]] = img
            offset += cols[i]
    return out
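
A short usage sketch (editorial) of the stack_images helper above: images of different sizes are placed side by side (or top to bottom) on a shared uint8 canvas.

import numpy as np

a = (np.random.rand(50, 80, 3) * 255).astype('uint8')
b = (np.random.rand(70, 40, 3) * 255).astype('uint8')
print(stack_images([a, b]).shape)                  # (70, 120, 3): side by side
print(stack_images([a, b], vertical=True).shape)   # (120, 80, 3): stacked top to bottom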

if __name__ == '__main__':
    imglist = []
    for i in range(100):
...