Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
52a4a0a8
Commit
52a4a0a8
authored
Nov 23, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Image to Image
parent
d31ba459
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
222 additions
and
12 deletions
+222
-12
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+4
-2
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+203
-0
examples/GAN/README.md
examples/GAN/README.md
+8
-1
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+1
-4
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+2
-2
tensorpack/models/nonlin.py
tensorpack/models/nonlin.py
+1
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+1
-1
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+1
-0
tensorpack/utils/viz.py
tensorpack/utils/viz.py
+1
-1
No files found.
examples/GAN/DCGAN-CelebA.py
View file @
52a4a0a8
...
@@ -18,11 +18,13 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
...
@@ -18,11 +18,13 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
"""
"""
DCGAN on CelebA dataset.
DCGAN on CelebA dataset.
The original code (dcgan.torch) uses kernel_shape=4, but I found the difference not significant.
1. Download the 'aligned&cropped' version of CelebA dataset.
1. Download the 'aligned&cropped' version of CelebA dataset.
2. Start training:
2. Start training:
./
c
elebA.py --data /path/to/image_align_celeba/
./
DCGAN-C
elebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model:
3. Visualize samples of a trained model:
./
c
elebA.py --load model.tfmodel --sample
./
DCGAN-C
elebA.py --load model.tfmodel --sample
"""
"""
SHAPE
=
64
SHAPE
=
64
...
...
examples/GAN/Image2Image.py
0 → 100755
View file @
52a4a0a8
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: Image2Image.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
numpy
as
np
import
tensorflow
as
tf
import
glob
,
pickle
import
os
,
sys
import
argparse
import
cv2
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
,
summary_moving_average
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
GAN
import
GANTrainer
,
RandomZData
,
build_GAN_losses
"""
To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain many 512x256 images formed by A and B
To visualize:
./Image2Image.py --data /path/to/test/datadir --mode {AtoB,BtoA} --load pretrained.model
"""
SHAPE
=
256
BATCH
=
16
IN_CH
=
3
OUT_CH
=
3
LAMBDA
=
100
NF
=
64
# number of filter
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
return
[
InputVar
(
tf
.
float32
,
(
None
,
SHAPE
,
SHAPE
,
IN_CH
),
'input'
)
,
InputVar
(
tf
.
float32
,
(
None
,
SHAPE
,
SHAPE
,
OUT_CH
),
'output'
)
]
def
generator
(
self
,
imgs
):
# imgs: input: 256x256xch
# U-Net structure, slightly different from the original on the location of relu/lrelu
with
argscope
(
BatchNorm
,
use_local_stat
=
True
),
\
argscope
(
Dropout
,
is_training
=
True
):
# always use local stat for BN, and apply dropout even in testing
with
argscope
(
Conv2D
,
kernel_shape
=
4
,
stride
=
2
,
nl
=
lambda
x
,
name
:
LeakyReLU
(
BatchNorm
(
'bn'
,
x
),
name
=
name
)):
e1
=
Conv2D
(
'conv1'
,
imgs
,
NF
,
nl
=
LeakyReLU
)
e2
=
Conv2D
(
'conv2'
,
e1
,
NF
*
2
)
e3
=
Conv2D
(
'conv3'
,
e2
,
NF
*
4
)
e4
=
Conv2D
(
'conv4'
,
e3
,
NF
*
8
)
e5
=
Conv2D
(
'conv5'
,
e4
,
NF
*
8
)
e6
=
Conv2D
(
'conv6'
,
e5
,
NF
*
8
)
e7
=
Conv2D
(
'conv7'
,
e6
,
NF
*
8
)
e8
=
Conv2D
(
'conv8'
,
e7
,
NF
*
8
,
nl
=
BNReLU
)
# 1x1
with
argscope
(
Deconv2D
,
nl
=
BNReLU
,
kernel_shape
=
4
,
stride
=
2
):
return
(
LinearWrap
(
e8
)
.
Deconv2D
(
'deconv1'
,
NF
*
8
)
.
Dropout
()
.
ConcatWith
(
3
,
e7
)
.
Deconv2D
(
'deconv2'
,
NF
*
8
)
.
Dropout
()
.
ConcatWith
(
3
,
e6
)
.
Deconv2D
(
'deconv3'
,
NF
*
8
)
.
Dropout
()
.
ConcatWith
(
3
,
e5
)
.
Deconv2D
(
'deconv4'
,
NF
*
8
)
.
ConcatWith
(
3
,
e4
)
.
Deconv2D
(
'deconv5'
,
NF
*
4
)
.
ConcatWith
(
3
,
e3
)
.
Deconv2D
(
'deconv6'
,
NF
*
2
)
.
ConcatWith
(
3
,
e2
)
.
Deconv2D
(
'deconv7'
,
NF
*
1
)
.
ConcatWith
(
3
,
e1
)
.
Deconv2D
(
'deconv8'
,
OUT_CH
,
nl
=
tf
.
tanh
)())
def
discriminator
(
self
,
inputs
,
outputs
):
""" return a (b, 1) logits"""
l
=
tf
.
concat
(
3
,
[
inputs
,
outputs
])
with
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
kernel_shape
=
4
,
stride
=
2
):
l
=
(
LinearWrap
(
l
)
.
Conv2D
(
'conv0'
,
NF
,
nl
=
LeakyReLU
)
.
Conv2D
(
'conv1'
,
NF
*
2
)
.
BatchNorm
(
'bn1'
)
.
LeakyReLU
()
.
Conv2D
(
'conv2'
,
NF
*
4
)
.
BatchNorm
(
'bn2'
)
.
LeakyReLU
()
.
Conv2D
(
'conv3'
,
NF
*
8
,
stride
=
1
)
# valid?
.
BatchNorm
(
'bn3'
)
.
LeakyReLU
()
.
Conv2D
(
'convlast'
,
1
,
stride
=
1
)())
return
l
def
_build_graph
(
self
,
input_vars
):
input
,
output
=
input_vars
input
,
output
=
input
/
128.0
-
1
,
output
/
128.0
-
1
with
argscope
([
Conv2D
,
Deconv2D
],
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)),
\
argscope
(
LeakyReLU
,
alpha
=
0.2
):
with
tf
.
variable_scope
(
'gen'
):
fake_output
=
self
.
generator
(
input
)
with
tf
.
variable_scope
(
'discrim'
):
real_pred
=
self
.
discriminator
(
input
,
output
)
with
tf
.
variable_scope
(
'discrim'
,
reuse
=
True
):
fake_pred
=
self
.
discriminator
(
input
,
fake_output
)
self
.
g_loss
,
self
.
d_loss
=
build_GAN_losses
(
real_pred
,
fake_pred
)
errL1
=
tf
.
reduce_mean
(
tf
.
abs
(
fake_output
-
output
),
name
=
'L1_loss'
)
self
.
g_loss
=
tf
.
add
(
self
.
g_loss
,
LAMBDA
*
errL1
,
name
=
'total_g_loss'
)
add_moving_summary
(
errL1
,
self
.
g_loss
)
# visualization
if
IN_CH
==
1
:
input
=
tf
.
image
.
grayscale_to_rgb
(
input
)
if
OUT_CH
==
1
:
output
=
tf
.
image
.
grayscale_to_rgb
(
output
)
fake_output
=
tf
.
image
.
grayscale_to_rgb
(
fake_output
)
viz
=
(
tf
.
concat
(
2
,
[
input
,
output
,
fake_output
])
+
1.0
)
*
128.0
viz
=
tf
.
cast
(
viz
,
tf
.
uint8
,
name
=
'viz'
)
tf
.
image_summary
(
'gen'
,
viz
,
max_images
=
max
(
30
,
BATCH
))
all_vars
=
tf
.
trainable_variables
()
self
.
g_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'gen/'
)]
self
.
d_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'discrim/'
)]
def
split_input
(
img
):
"""
img: an 512x256x3 image
:return: [input, output]
"""
input
,
output
=
img
[:,:
256
,:],
img
[:,
256
:,:]
if
args
.
mode
==
'BtoA'
:
input
,
output
=
output
,
input
if
IN_CH
==
1
:
input
=
cv2
.
cvtColor
(
input
,
cv2
.
COLOR_RGB2GRAY
)
if
OUT_CH
==
1
:
output
=
cv2
.
cvtColor
(
output
,
cv2
.
COLOR_RGB2GRAY
)
return
[
input
,
output
]
def
get_data
():
datadir
=
args
.
data
# assume each image is 512x256 split to left and right
imgs
=
glob
.
glob
(
os
.
path
.
join
(
datadir
,
'*.jpg'
))
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
ds
=
MapData
(
ds
,
lambda
dp
:
split_input
(
dp
[
0
]))
augs
=
[
imgaug
.
Resize
(
286
),
imgaug
.
RandomCrop
(
256
)
]
ds
=
AugmentImageComponents
(
ds
,
augs
,
(
0
,
1
))
ds
=
BatchData
(
ds
,
BATCH
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
return
ds
def
get_config
():
logger
.
auto_set_dir
()
dataset
=
get_data
()
lr
=
symbolic_functions
.
get_scalar_var
(
'learning_rate'
,
2e-4
,
summary
=
True
)
return
TrainConfig
(
dataset
=
dataset
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
beta1
=
0.5
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
200
,
1e-4
)])
]),
session_config
=
get_default_sess_config
(
0.8
),
model
=
Model
(),
step_per_epoch
=
300
,
max_epoch
=
300
,
)
def
sample
(
datadir
,
model_path
):
pred
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
model
=
Model
(),
input_names
=
[
'input'
,
'output'
],
output_names
=
[
'viz'
])
imgs
=
glob
.
glob
(
os
.
path
.
join
(
datadir
,
'*.jpg'
))
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
ds
=
BatchData
(
MapData
(
ds
,
lambda
dp
:
split_input
(
dp
[
0
])),
16
)
pred
=
SimpleDatasetPredictor
(
pred
,
ds
)
for
o
in
pred
.
get_result
():
o
=
o
[:,:,:,::
-
1
]
viz
=
next
(
build_patch_list
(
o
,
nr_row
=
4
,
nr_col
=
4
,
viz
=
True
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--sample'
,
action
=
'store_true'
,
help
=
'run sampling'
)
parser
.
add_argument
(
'--data'
,
help
=
'A directory of images'
)
parser
.
add_argument
(
'--mode'
,
choices
=
[
'AtoB'
,
'BtoA'
],
default
=
'AtoB'
)
global
args
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
if
args
.
sample
:
sample
(
args
.
data
,
args
.
load
)
else
:
assert
args
.
data
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
GANTrainer
(
config
,
g_vs_d
=
1
)
.
train
()
examples/GAN/README.md
View file @
52a4a0a8
# Deep Convolutional Generative Adversarial Networks
# Generative Adversarial Networks
## DCGAN-CelebA.py
Reproduce DCGAN following the setup in
[
dcgan.torch
](
https://github.com/soumith/dcgan.torch
)
.
Reproduce DCGAN following the setup in
[
dcgan.torch
](
https://github.com/soumith/dcgan.torch
)
.
...
@@ -13,3 +15,8 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv
...
@@ -13,3 +15,8 @@ Play with the [pretrained model](https://drive.google.com/drive/folders/0B9IPQTv


See the docstring in the script for usage.
See the docstring in the script for usage.
## Image2Image.py
Reproduce
[
Image-to-image Translation with Conditional Adversarial Networks
](
https://arxiv.org/pdf/1611.07004v1.pdf
)
,
following the setup in
[
pix2pix
](
https://github.com/phillipi/pix2pix
)
.
tensorpack/callbacks/dump.py
View file @
52a4a0a8
...
@@ -14,17 +14,14 @@ __all__ = ['DumpParamAsImage']
...
@@ -14,17 +14,14 @@ __all__ = ['DumpParamAsImage']
class
DumpParamAsImage
(
Callback
):
class
DumpParamAsImage
(
Callback
):
"""
"""
Dump a variable to image(s) after every epoch.
Dump a variable to image(s) after every epoch
to logger.LOG_DIR
.
"""
"""
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
def
__init__
(
self
,
var_name
,
prefix
=
None
,
map_func
=
None
,
scale
=
255
,
clip
=
False
):
"""
"""
:param var_name: the name of the variable.
:param var_name: the name of the variable.
:param prefix: the filename prefix for saved images. Default is the op name.
:param prefix: the filename prefix for saved images. Default is the op name.
:param map_func: map the value of the variable to an image or list of
:param map_func: map the value of the variable to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity
images of shape [h, w] or [h, w, c]. If None, will use identity
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param scale: a multiplier on pixel values, applied after map_func. default to 255
:param clip: whether to clip the result to [0, 255]
:param clip: whether to clip the result to [0, 255]
"""
"""
...
...
tensorpack/dataflow/image.py
View file @
52a4a0a8
...
@@ -22,6 +22,7 @@ class ImageFromFile(RNGDataFlow):
...
@@ -22,6 +22,7 @@ class ImageFromFile(RNGDataFlow):
assert
len
(
files
)
assert
len
(
files
)
self
.
files
=
files
self
.
files
=
files
self
.
channel
=
int
(
channel
)
self
.
channel
=
int
(
channel
)
self
.
imread_mode
=
cv2
.
IMREAD_GRAYSCALE
if
self
.
channel
==
1
else
cv2
.
IMREAD_COLOR
self
.
resize
=
resize
self
.
resize
=
resize
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
...
@@ -32,8 +33,7 @@ class ImageFromFile(RNGDataFlow):
...
@@ -32,8 +33,7 @@ class ImageFromFile(RNGDataFlow):
if
self
.
shuffle
:
if
self
.
shuffle
:
self
.
rng
.
shuffle
(
self
.
files
)
self
.
rng
.
shuffle
(
self
.
files
)
for
f
in
self
.
files
:
for
f
in
self
.
files
:
im
=
cv2
.
imread
(
im
=
cv2
.
imread
(
f
,
self
.
imread_mode
)
f
,
cv2
.
IMREAD_GRAYSCALE
if
self
.
channel
==
1
else
cv2
.
IMREAD_COLOR
)
if
self
.
channel
==
3
:
if
self
.
channel
==
3
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
if
self
.
resize
is
not
None
:
if
self
.
resize
is
not
None
:
...
...
tensorpack/models/nonlin.py
View file @
52a4a0a8
...
@@ -66,6 +66,6 @@ def LeakyReLU(x, alpha, name=None):
...
@@ -66,6 +66,6 @@ def LeakyReLU(x, alpha, name=None):
@
layer_register
(
log_shape
=
False
,
use_scope
=
False
)
@
layer_register
(
log_shape
=
False
,
use_scope
=
False
)
def
BNReLU
(
x
,
name
=
None
):
def
BNReLU
(
x
,
name
=
None
):
x
=
BatchNorm
(
'bn'
,
x
,
use_local_stat
=
None
)
x
=
BatchNorm
(
'bn'
,
x
)
x
=
tf
.
nn
.
relu
(
x
,
name
=
name
)
x
=
tf
.
nn
.
relu
(
x
,
name
=
name
)
return
x
return
x
tensorpack/models/regularize.py
View file @
52a4a0a8
...
@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name=None):
...
@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name=None):
return
tf
.
add_n
(
costs
,
name
=
name
)
return
tf
.
add_n
(
costs
,
name
=
name
)
@
layer_register
(
log_shape
=
False
)
@
layer_register
(
log_shape
=
False
,
use_scope
=
False
)
def
Dropout
(
x
,
keep_prob
=
0.5
,
is_training
=
None
):
def
Dropout
(
x
,
keep_prob
=
0.5
,
is_training
=
None
):
"""
"""
:param is_training: if None, will use the current context by default.
:param is_training: if None, will use the current context by default.
...
...
tensorpack/predict/dataset.py
View file @
52a4a0a8
...
@@ -55,6 +55,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
...
@@ -55,6 +55,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
def
get_result
(
self
):
def
get_result
(
self
):
""" A generator to produce prediction for each data"""
""" A generator to produce prediction for each data"""
self
.
dataset
.
reset_state
()
try
:
try
:
sz
=
self
.
dataset
.
size
()
sz
=
self
.
dataset
.
size
()
except
NotImplementedError
:
except
NotImplementedError
:
...
...
tensorpack/utils/viz.py
View file @
52a4a0a8
...
@@ -77,7 +77,7 @@ def build_patch_list(patch_list,
...
@@ -77,7 +77,7 @@ def build_patch_list(patch_list,
viz
=
False
,
lclick_cb
=
None
):
viz
=
False
,
lclick_cb
=
None
):
"""
"""
Generate patches.
Generate patches.
:param patch_list: bhw or bhwc
:param patch_list: bhw or bhwc
images in [0,255]
:param border: defaults to 0.1 * max(image_width, image_height)
:param border: defaults to 0.1 * max(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment