Shashank Suhas / seminar-breakout

Commit 9f22aa91
authored May 23, 2017 by Yuxin Wu

add cyclegan
Parent: 00811100

Showing 7 changed files with 241 additions and 6 deletions:
  README.md                      (+1, -1)
  examples/GAN/CycleGAN.py       (+227, -0)
  examples/GAN/GAN.py            (+3, -0)
  examples/GAN/README.md         (+7, -2)
  examples/README.md             (+1, -1)
  tensorpack/dataflow/common.py  (+1, -1)
  tensorpack/train/multigpu.py   (+1, -1)
README.md

@@ -9,7 +9,7 @@ See some [examples](examples) to learn about the framework:

 ### Vision:
 + [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
 + [Train ResNet on ImageNet / Cifar10 / SVHN](examples/ResNet)
-+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image.
++ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
 + [Visualize Saliency Maps by Guided ReLU](examples/Saliency)
examples/GAN/CycleGAN.py (new file, mode 100755)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: CycleGAN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os, sys
import argparse
import glob
from six.moves import map, zip, range
import numpy as np

from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANTrainer, GANModelDesc

"""
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
2. ./CycleGAN.py --data /path/to/datasets/horse2zebra
Training and testing visualizations will be in tensorboard.
"""
SHAPE = 256
BATCH = 1
TEST_BATCH = 32
NF = 64  # channel size
def INReLU(x, name=None):
    x = InstanceNorm('inorm', x)
    return tf.nn.relu(x, name=name)


def INLReLU(x, name=None):
    x = InstanceNorm('inorm', x)
    return LeakyReLU(x, name=name)
class Model(GANModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputA'),
                InputDesc(tf.float32, (None, SHAPE, SHAPE, 3), 'inputB')]
    @staticmethod
    def build_res_block(x, name, chan, first=False):
        with tf.variable_scope(name):
            input = x
            return (LinearWrap(x)
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv0', chan, padding='VALID')
                    .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
                    .Conv2D('conv1', chan, padding='VALID', nl=tf.identity)
                    .InstanceNorm('inorm')()) + input
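    # Note: each SYMMETRIC pad + VALID 3x3 conv above leaves H and W unchanged
    # (pad by 1 per side, then shrink by 1), so the residual addition is
    # shape-compatible. kernel_shape=3 is supplied by the argscope in
    # generator() below, this block's only caller.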
    @auto_reuse_variable_scope
    def generator(self, img):
        assert img is not None
        with argscope([Conv2D, Deconv2D], nl=INReLU, kernel_shape=3):
            l = (LinearWrap(img)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('conv0', NF, kernel_shape=7, padding='VALID')
                 .Conv2D('conv1', NF * 2, stride=2)
                 .Conv2D('conv2', NF * 4, stride=2)())
            for k in range(9):
                l = Model.build_res_block(l, 'res{}'.format(k), NF * 4, first=(k == 0))
            l = (LinearWrap(l)
                 .Deconv2D('deconv0', NF * 2, stride=2)
                 .Deconv2D('deconv1', NF * 1, stride=2)
                 .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
                 .Conv2D('convlast', 3, kernel_shape=7, padding='VALID',
                         nl=tf.tanh, use_bias=True)())
        return l
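    # The stack above matches the 9-res-block generator the CycleGAN paper uses
    # for 256x256 inputs: a 7x7 conv (64 ch), two stride-2 downsampling convs
    # (128, 256 ch), 9 residual blocks, two stride-2 deconvs back up, and a
    # final 7x7 conv to 3 channels with a tanh output in [-1, 1].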
    @auto_reuse_variable_scope
    def discriminator(self, img):
        with argscope(Conv2D, nl=INLReLU, kernel_shape=4, stride=2):
            l = (LinearWrap(img)
                 .Conv2D('conv0', NF, nl=LeakyReLU)
                 .Conv2D('conv1', NF * 2)
                 .Conv2D('conv2', NF * 4)
                 .Conv2D('conv3', NF * 8, stride=1)
                 .Conv2D('conv4', 1, stride=1, nl=tf.identity, use_bias=True)())
        return l
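    # 4x4 convs with strides 2,2,2,1,1 and a 1-channel linear output: this is
    # the PatchGAN-style discriminator of pix2pix/CycleGAN, which scores a grid
    # of overlapping patches rather than emitting a single scalar per image.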
    def _build_graph(self, inputs):
        A, B = inputs
        A = tf.transpose(A / 128.0 - 1.0, [0, 3, 1, 2])
        B = tf.transpose(B / 128.0 - 1.0, [0, 3, 1, 2])
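        # The two lines above rescale uint8 pixels from [0, 255] to roughly
        # [-1, 1] (matching the generator's tanh output) and move the data
        # from NHWC to the NCHW layout set by the argscope below.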
        def viz3(name, a, b, c):
            im = tf.concat([a, b, c], axis=3)
            im = tf.transpose(im, [0, 2, 3, 1])
            im = (im + 1.0) * 128
            im = tf.clip_by_value(im, 0, 255)
            im = tf.cast(im, tf.uint8, name='viz_' + name)
            tf.summary.image(name, im, max_outputs=50)

        # use the initializers from torch
        with argscope([Conv2D, Deconv2D], use_bias=False,
                      W_init=tf.random_normal_initializer(stddev=0.02)), \
                argscope([Conv2D, Deconv2D, InstanceNorm], data_format='NCHW'), \
                argscope(LeakyReLU, alpha=0.2):
            with tf.variable_scope('gen'):
                with tf.variable_scope('B'):
                    AB = self.generator(A)
                with tf.variable_scope('A'):
                    BA = self.generator(B)
                    ABA = self.generator(AB)
                with tf.variable_scope('B'):
                    BAB = self.generator(BA)

            viz3('A_recon', A, AB, ABA)
            viz3('B_recon', B, BA, BAB)

            with tf.variable_scope('discrim'):
                with tf.variable_scope('A'):
                    A_dis_real = self.discriminator(A)
                    A_dis_fake = self.discriminator(BA)
                with tf.variable_scope('B'):
                    B_dis_real = self.discriminator(B)
                    B_dis_fake = self.discriminator(AB)
        def LSGAN_losses(real, fake):
            with tf.name_scope('LSGAN_losses'):
                d_real = tf.reduce_mean(tf.squared_difference(real, 0.9), name='d_real')
                d_fake = tf.reduce_mean(tf.square(fake), name='d_fake')
                d_loss = tf.multiply(d_real + d_fake, 0.5, name='d_loss')
                g_loss = tf.reduce_mean(tf.squared_difference(fake, 0.9), name='g_loss')
                add_moving_summary(g_loss, d_loss)
                return g_loss, d_loss
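        # In math terms, LSGAN_losses implements the least-squares GAN
        # objective with a smoothed real label of 0.9 instead of 1:
        #   d_loss = 0.5 * ( mean((D(real) - 0.9)^2) + mean(D(fake)^2) )
        #   g_loss = mean((D(fake) - 0.9)^2)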
        with tf.name_scope('LossA'):
            # reconstruction loss
            recon_loss_A = tf.reduce_mean(tf.abs(A - ABA), name='recon_loss')
            # gan loss
            G_loss_A, D_loss_A = LSGAN_losses(A_dis_real, A_dis_fake)

        with tf.name_scope('LossB'):
            recon_loss_B = tf.reduce_mean(tf.abs(B - BAB), name='recon_loss')
            G_loss_B, D_loss_B = LSGAN_losses(B_dis_real, B_dis_fake)

        LAMBDA = 10.0
        self.g_loss = tf.add((G_loss_A + G_loss_B),
                             (recon_loss_A + recon_loss_B) * LAMBDA, name='G_loss_total')
        self.d_loss = tf.add(D_loss_A, D_loss_B, name='D_loss_total')
        self.collect_variables('gen', 'discrim')

        add_moving_summary(recon_loss_A, recon_loss_B, self.g_loss, self.d_loss)
    def _get_optimizer(self):
        lr = symbf.get_scalar_var('learning_rate', 2e-4, summary=True)
        return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_data(datadir, isTrain=True):
    if isTrain:
        augs = [
            imgaug.Resize(int(SHAPE * 1.12)),
            imgaug.RandomCrop(SHAPE),
        ]
    else:
        augs = [imgaug.Resize(SHAPE)]

    def get_image_pairs(dir1, dir2):
        def get_df(dir):
            files = sorted(glob.glob(os.path.join(dir, '*.jpg')))
            df = ImageFromFile(files, channel=3, shuffle=isTrain)
            return AugmentImageComponent(df, augs)
        return JoinData([get_df(dir1), get_df(dir2)])

    names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
    df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
    df = BatchData(df, BATCH if isTrain else TEST_BATCH)
    df = PrefetchDataZMQ(df, 2 if isTrain else 1)
    return df
class VisualizeTestSet(Callback):
    def _setup_graph(self):
        self.pred = self.trainer.get_predictor(
            ['inputA', 'inputB'], ['viz_A_recon', 'viz_B_recon'])

    def _before_train(self):
        global args
        self.val_ds = get_data(args.data, isTrain=False)

    def _trigger(self):
        idx = 0
        for iA, iB in self.val_ds.get_data():
            vizA, vizB = self.pred(iA, iB)
            self.trainer.monitors.put_image('testA-{}'.format(idx), vizA)
            self.trainer.monitors.put_image('testB-{}'.format(idx), vizB)
            idx += 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True,
                        help='the image directory. should contain trainA/trainB/testA/testB')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()

    logger.auto_set_dir()

    data = get_data(args.data)
    data = PrintData(data)

    config = TrainConfig(
        model=Model(),
        dataflow=data,
        callbacks=[
            ModelSaver(),
            ScheduledHyperParamSetter(
                'learning_rate',
                [(100, 2e-4), (200, 0)], interp='linear'),
            PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
        ],
        max_epoch=195,
        session_init=SaverRestore(args.load) if args.load else None
    )
    GANTrainer(config).train()
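After training, the saved checkpoint can be applied to new images. The snippet
below is a minimal sketch only, assuming tensorpack's PredictConfig /
OfflinePredictor API of this era; the checkpoint path and image filenames are
placeholders, while 'inputA'/'inputB' and the 'viz_*' tensor names come from
the script above.

import cv2
from tensorpack import OfflinePredictor, PredictConfig, SaverRestore
from CycleGAN import Model, SHAPE

pred = OfflinePredictor(PredictConfig(
    model=Model(),
    session_init=SaverRestore('train_log/CycleGAN/checkpoint'),  # placeholder path
    input_names=['inputA', 'inputB'],
    output_names=['viz_A_recon', 'viz_B_recon']))

# one image per domain; viz_A_recon tiles [A, A->B, A->B->A] side by side
imgA = cv2.resize(cv2.imread('horse.jpg'), (SHAPE, SHAPE))[None].astype('float32')
imgB = cv2.resize(cv2.imread('zebra.jpg'), (SHAPE, SHAPE))[None].astype('float32')
vizA, vizB = pred(imgA, imgB)
cv2.imwrite('A_recon.png', vizA[0])
cv2.imwrite('B_recon.png', vizB[0])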
examples/GAN/GAN.py

@@ -109,6 +109,9 @@ class SeparateGANTrainer(FeedfreeTrainerBase):

 class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
+    """
+    A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
+    """
     def __init__(self, config):
         super(MultiGPUGANTrainer, self).__init__(config)
         self._nr_gpu = config.nr_tower
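Assuming MultiGPUGANTrainer keeps GANTrainer's call interface (the new
docstring describes it as a replacement), the CycleGAN script above could
plausibly switch to multi-GPU training as in this sketch; the tower count is
illustrative, not from this commit:

from GAN import MultiGPUGANTrainer

config.nr_tower = 2                   # one tower per GPU, read as config.nr_tower above
MultiGPUGANTrainer(config).train()    # in place of GANTrainer(config).train()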
examples/GAN/README.md

@@ -18,6 +18,9 @@ Reproduce the following GAN-related methods:

 + BEGAN ([BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717))
++ CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))

 Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.

 ## DCGAN.py

@@ -65,6 +68,8 @@ Some BEGAN samples:

 

-## DiscoGAN-CelebA.py
+## CycleGAN.py, DiscoGAN-CelebA.py

-Reproduce DiscoGAN on CelebA.
+Reproduce CycleGAN with the original datasets, and DiscoGAN on CelebA. They are pretty much the same idea with different architecture.

 
examples/README.md

@@ -12,7 +12,7 @@ Training examples with __reproducible__ and meaningful performance.

 + [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
 + [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
 + [Train ResNet for ImageNet/Cifar10/SVHN](ResNet)
-+ [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image.
++ [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
 + [Inception-BN with 71% accuracy](Inception/inception-bn.py)
 + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
tensorpack/dataflow/common.py

@@ -578,7 +578,7 @@ class CacheData(ProxyDataFlow):

 class PrintData(ProxyDataFlow):
     """
-    Behave like an identity mapping but print shapes of produced datapoints once during construction.
+    Behave like an identity mapping but print shape and range of the first datapoint once during construction.

     Attributes:
         label (str): label to identify the data when using this debugging on multiple places.
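PrintData is exactly what CycleGAN.py above wraps its dataflow with
(data = PrintData(data)) before training. A minimal sketch, reusing get_data()
from that script with a placeholder path:

from tensorpack.dataflow import PrintData

df = get_data('/path/to/datasets/horse2zebra')  # placeholder, as in the docstring
df = PrintData(df)  # logs shape/range of the first datapoint, then acts as identity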
tensorpack/train/multigpu.py

@@ -29,7 +29,7 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',

 def _check_tf_version():
     ver = float('.'.join(tf.VERSION.split('.')[:2]))
-    assert ver >= 1.1, "TF version {} is too old to run multi GPU training!".format(ver)
+    assert ver >= 1.1, "TF version {} is too old to run multi GPU training!".format(tf.VERSION)

 def apply_prefetch_policy(config, use_stage=True):
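This one-line change makes the assertion message report the full version
string instead of the truncated major.minor float. An illustration with a
hypothetical version value:

# with tf.VERSION == '1.0.1':
ver = float('.'.join('1.0.1'.split('.')[:2]))  # -> 1.0
# old message: "TF version 1.0 is too old to run multi GPU training!"   (patch level lost)
# new message: "TF version 1.0.1 is too old to run multi GPU training!" (tf.VERSION verbatim)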