Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
7a008efe
Commit
7a008efe
authored
Feb 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
swap back 0,1 labels of GAN (#107)
parent
14864868
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
21 deletions
+18
-21
.travis.yml
.travis.yml
+5
-5
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+5
-5
examples/GAN/GAN.py
examples/GAN/GAN.py
+7
-11
scripts/dump-model-params.py
scripts/dump-model-params.py
+1
-0
No files found.
.travis.yml
View file @
7a008efe
...
...
@@ -16,18 +16,18 @@ matrix:
include
:
-
os
:
linux
python
:
2.7
env
:
TF_VERSION=1.0.0rc
1
TF_TYPE=release
env
:
TF_VERSION=1.0.0rc
2
TF_TYPE=release
-
os
:
linux
python
:
3.5
env
:
TF_VERSION=1.0.0rc
1
TF_TYPE=release
env
:
TF_VERSION=1.0.0rc
2
TF_TYPE=release
-
os
:
linux
python
:
2.7
env
:
TF_VERSION=1.0.0rc
1
TF_TYPE=nightly
env
:
TF_VERSION=1.0.0rc
2
TF_TYPE=nightly
-
os
:
linux
python
:
3.5
env
:
TF_VERSION=1.0.0rc
1
TF_TYPE=nightly
env
:
TF_VERSION=1.0.0rc
2
TF_TYPE=nightly
allow_failures
:
-
env
:
TF_
VERSION=1.0.0rc1 TF_
TYPE=nightly
-
env
:
TF_TYPE=nightly
install
:
-
pip install -U pip
# the pip version on travis is too old
...
...
examples/GAN/DCGAN-CelebA.py
View file @
7a008efe
...
...
@@ -22,6 +22,7 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
DCGAN on CelebA dataset.
1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
(or just use any directory of jpg files).
2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/
...
...
@@ -39,7 +40,7 @@ class Model(GANModelDesc):
return
[
InputDesc
(
tf
.
float32
,
(
None
,
SHAPE
,
SHAPE
,
3
),
'input'
)]
def
generator
(
self
,
z
):
""" return a image generated from z"""
""" return a
n
image generated from z"""
nf
=
64
l
=
FullyConnected
(
'fc0'
,
z
,
nf
*
8
*
4
*
4
,
nl
=
tf
.
identity
)
l
=
tf
.
reshape
(
l
,
[
-
1
,
4
,
4
,
nf
*
8
])
...
...
@@ -79,7 +80,7 @@ class Model(GANModelDesc):
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
with
tf
.
variable_scope
(
'gen'
):
image_gen
=
self
.
generator
(
z
)
tf
.
summary
.
image
(
'gen'
,
image_gen
,
max_outputs
=
30
)
tf
.
summary
.
image
(
'gen
erated-samples
'
,
image_gen
,
max_outputs
=
30
)
with
tf
.
variable_scope
(
'discrim'
):
vecpos
=
self
.
discriminator
(
image_pos
)
with
tf
.
variable_scope
(
'discrim'
,
reuse
=
True
):
...
...
@@ -106,10 +107,8 @@ def get_data():
def
get_config
():
logger
.
auto_set_dir
()
dataset
=
get_data
()
return
TrainConfig
(
dataflow
=
dataset
,
dataflow
=
get_data
()
,
callbacks
=
[
ModelSaver
()],
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
...
...
@@ -145,6 +144,7 @@ if __name__ == '__main__':
sample
(
args
.
load
)
else
:
assert
args
.
data
logger
.
auto_set_dir
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
...
...
examples/GAN/GAN.py
View file @
7a008efe
...
...
@@ -31,14 +31,9 @@ class GANModelDesc(ModelDesc):
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs".
Args:
logits_real (tf.Tensor): discrim logits from real samples
logits_fake (tf.Tensor): discrim logits from fake samples produced by generator
Returns:
tf.Tensor: Description
"""
with
tf
.
name_scope
(
"GAN_loss"
):
score_real
=
tf
.
sigmoid
(
logits_real
)
...
...
@@ -48,20 +43,20 @@ class GANModelDesc(ModelDesc):
with
tf
.
name_scope
(
"discrim"
):
d_loss_pos
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
logits
=
logits_real
,
labels
=
tf
.
zero
s_like
(
logits_real
)),
name
=
'loss_real'
)
logits
=
logits_real
,
labels
=
tf
.
one
s_like
(
logits_real
)),
name
=
'loss_real'
)
d_loss_neg
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
logits
=
logits_fake
,
labels
=
tf
.
one
s_like
(
logits_fake
)),
name
=
'loss_fake'
)
logits
=
logits_fake
,
labels
=
tf
.
zero
s_like
(
logits_fake
)),
name
=
'loss_fake'
)
d_pos_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_real
<
0.5
,
tf
.
float32
),
name
=
'accuracy_real'
)
d_neg_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
>
0.5
,
tf
.
float32
),
name
=
'accuracy_fake'
)
d_pos_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_real
>
0.5
,
tf
.
float32
),
name
=
'accuracy_real'
)
d_neg_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
<
0.5
,
tf
.
float32
),
name
=
'accuracy_fake'
)
self
.
d_accuracy
=
tf
.
add
(
.5
*
d_pos_acc
,
.5
*
d_neg_acc
,
name
=
'accuracy'
)
self
.
d_loss
=
tf
.
add
(
.5
*
d_loss_pos
,
.5
*
d_loss_neg
,
name
=
'loss'
)
with
tf
.
name_scope
(
"gen"
):
self
.
g_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
logits
=
logits_fake
,
labels
=
tf
.
zero
s_like
(
logits_fake
)),
name
=
'loss'
)
self
.
g_accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
<
0.5
,
tf
.
float32
),
name
=
'accuracy'
)
logits
=
logits_fake
,
labels
=
tf
.
one
s_like
(
logits_fake
)),
name
=
'loss'
)
self
.
g_accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
>
0.5
,
tf
.
float32
),
name
=
'accuracy'
)
add_moving_summary
(
self
.
g_loss
,
self
.
d_loss
,
self
.
d_accuracy
,
self
.
g_accuracy
)
...
...
@@ -76,6 +71,7 @@ class GANTrainer(FeedfreeTrainerBase):
self
.
build_train_tower
()
opt
=
self
.
model
.
get_optimizer
()
# by default, run one d_min after one g_min
self
.
g_min
=
opt
.
minimize
(
self
.
model
.
g_loss
,
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
)
with
tf
.
control_dependencies
([
self
.
g_min
]):
self
.
d_min
=
opt
.
minimize
(
self
.
model
.
d_loss
,
var_list
=
self
.
model
.
d_vars
,
name
=
'd_op'
)
...
...
scripts/dump-model-params.py
View file @
7a008efe
...
...
@@ -35,6 +35,7 @@ with tf.Graph().as_default() as G:
else
:
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
sess
.
run
(
tf
.
global_variables_initializer
())
init
.
init
(
sess
)
# dump ...
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment