Commit 7a008efe authored by Yuxin Wu's avatar Yuxin Wu

swap back 0,1 labels of GAN (#107)

parent 14864868
...@@ -16,18 +16,18 @@ matrix: ...@@ -16,18 +16,18 @@ matrix:
include: include:
- os: linux - os: linux
python: 2.7 python: 2.7
env: TF_VERSION=1.0.0rc1 TF_TYPE=release env: TF_VERSION=1.0.0rc2 TF_TYPE=release
- os: linux - os: linux
python: 3.5 python: 3.5
env: TF_VERSION=1.0.0rc1 TF_TYPE=release env: TF_VERSION=1.0.0rc2 TF_TYPE=release
- os: linux - os: linux
python: 2.7 python: 2.7
env: TF_VERSION=1.0.0rc1 TF_TYPE=nightly env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
- os: linux - os: linux
python: 3.5 python: 3.5
env: TF_VERSION=1.0.0rc1 TF_TYPE=nightly env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
allow_failures: allow_failures:
- env: TF_VERSION=1.0.0rc1 TF_TYPE=nightly - env: TF_TYPE=nightly
install: install:
- pip install -U pip # the pip version on travis is too old - pip install -U pip # the pip version on travis is too old
......
...@@ -22,6 +22,7 @@ from GAN import GANTrainer, RandomZData, GANModelDesc ...@@ -22,6 +22,7 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
DCGAN on CelebA dataset. DCGAN on CelebA dataset.
1. Download the 'aligned&cropped' version of CelebA dataset 1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
(or just use any directory of jpg files). (or just use any directory of jpg files).
2. Start training: 2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/ ./DCGAN-CelebA.py --data /path/to/image_align_celeba/
...@@ -39,7 +40,7 @@ class Model(GANModelDesc): ...@@ -39,7 +40,7 @@ class Model(GANModelDesc):
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'input')] return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
def generator(self, z): def generator(self, z):
""" return a image generated from z""" """ return an image generated from z"""
nf = 64 nf = 64
l = FullyConnected('fc0', z, nf * 8 * 4 * 4, nl=tf.identity) l = FullyConnected('fc0', z, nf * 8 * 4 * 4, nl=tf.identity)
l = tf.reshape(l, [-1, 4, 4, nf * 8]) l = tf.reshape(l, [-1, 4, 4, nf * 8])
...@@ -79,7 +80,7 @@ class Model(GANModelDesc): ...@@ -79,7 +80,7 @@ class Model(GANModelDesc):
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
with tf.variable_scope('gen'): with tf.variable_scope('gen'):
image_gen = self.generator(z) image_gen = self.generator(z)
tf.summary.image('gen', image_gen, max_outputs=30) tf.summary.image('generated-samples', image_gen, max_outputs=30)
with tf.variable_scope('discrim'): with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos) vecpos = self.discriminator(image_pos)
with tf.variable_scope('discrim', reuse=True): with tf.variable_scope('discrim', reuse=True):
...@@ -106,10 +107,8 @@ def get_data(): ...@@ -106,10 +107,8 @@ def get_data():
def get_config(): def get_config():
logger.auto_set_dir()
dataset = get_data()
return TrainConfig( return TrainConfig(
dataflow=dataset, dataflow=get_data(),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5), session_config=get_default_sess_config(0.5),
model=Model(), model=Model(),
...@@ -145,6 +144,7 @@ if __name__ == '__main__': ...@@ -145,6 +144,7 @@ if __name__ == '__main__':
sample(args.load) sample(args.load)
else: else:
assert args.data assert args.data
logger.auto_set_dir()
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
......
...@@ -31,14 +31,9 @@ class GANModelDesc(ModelDesc): ...@@ -31,14 +31,9 @@ class GANModelDesc(ModelDesc):
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))] min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs".
Args: Args:
logits_real (tf.Tensor): discrim logits from real samples logits_real (tf.Tensor): discrim logits from real samples
logits_fake (tf.Tensor): discrim logits from fake samples produced by generator logits_fake (tf.Tensor): discrim logits from fake samples produced by generator
Returns:
tf.Tensor: Description
""" """
with tf.name_scope("GAN_loss"): with tf.name_scope("GAN_loss"):
score_real = tf.sigmoid(logits_real) score_real = tf.sigmoid(logits_real)
...@@ -48,20 +43,20 @@ class GANModelDesc(ModelDesc): ...@@ -48,20 +43,20 @@ class GANModelDesc(ModelDesc):
with tf.name_scope("discrim"): with tf.name_scope("discrim"):
d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_real, labels=tf.zeros_like(logits_real)), name='loss_real') logits=logits_real, labels=tf.ones_like(logits_real)), name='loss_real')
d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss_fake') logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss_fake')
d_pos_acc = tf.reduce_mean(tf.cast(score_real < 0.5, tf.float32), name='accuracy_real') d_pos_acc = tf.reduce_mean(tf.cast(score_real > 0.5, tf.float32), name='accuracy_real')
d_neg_acc = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy_fake') d_neg_acc = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy_fake')
self.d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy') self.d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')
self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss') self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss')
with tf.name_scope("gen"): with tf.name_scope("gen"):
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss') logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss')
self.g_accuracy = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy') self.g_accuracy = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy')
add_moving_summary(self.g_loss, self.d_loss, self.d_accuracy, self.g_accuracy) add_moving_summary(self.g_loss, self.d_loss, self.d_accuracy, self.g_accuracy)
...@@ -76,6 +71,7 @@ class GANTrainer(FeedfreeTrainerBase): ...@@ -76,6 +71,7 @@ class GANTrainer(FeedfreeTrainerBase):
self.build_train_tower() self.build_train_tower()
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
# by default, run one d_min after one g_min
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op') self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
with tf.control_dependencies([self.g_min]): with tf.control_dependencies([self.g_min]):
self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op') self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
......
...@@ -35,6 +35,7 @@ with tf.Graph().as_default() as G: ...@@ -35,6 +35,7 @@ with tf.Graph().as_default() as G:
else: else:
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
init.init(sess) init.init(sess)
# dump ... # dump ...
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment