Commit 13c96b94, authored Feb 15, 2017 by Yuxin Wu
add a WGAN example
parent 00fdb263

Showing 5 changed files, with 117 additions and 30 deletions (+117 / -30)
examples/GAN/README.md          +6    -0
examples/GAN/WGAN-CelebA.py     +104  -0
tensorpack/tfutils/sessinit.py  +2    -7
tensorpack/utils/utils.py       +1    -1
tensorpack/utils/viz.py         +4    -22
examples/GAN/README.md

@@ -10,6 +10,8 @@ Reproduce the following GAN-related methods:

+ Conditional GAN
+ [Wasserstein GAN](https://arxiv.org/abs/1701.07875)

Please see the __docstring__ in each script for detailed usage.

## DCGAN-CelebA.py

@@ -51,3 +53,7 @@ It then maximizes mutual information between these latent variables and the image.

## ConditionalGAN-mnist.py

Train a simple GAN on mnist, conditioned on the class labels.

## WGAN-CelebA.py

Reproduce WGAN with a few small modifications to DCGAN-CelebA.py.
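For reference, the losses implemented in the new script below replace DCGAN's sigmoid cross-entropy with the raw critic score. A minimal summary of the objectives being minimized, where D denotes the critic (discriminator) output, p_data the data distribution and p_g the generator distribution:

$$
L_D = \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[D(x)\big],
\qquad
L_G = -\,\mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big]
$$

The script optimizes these with RMSProp and clips the critic weights to [-0.01, 0.01] after every update, as in the WGAN paper.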
examples/GAN/WGAN-CelebA.py (new file, mode 100755)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: WGAN-CelebA.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import os
import argparse

from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from GAN import GANTrainer

"""
Wasserstein-GAN.
See the docstring in DCGAN-CelebA.py for usage.

Actually, just using the clip is enough for WGAN to work (even without BN in generator).
The wasserstein loss is not the key factor.
"""

# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN-CelebA, and change the batch size & model
import imp
DCGAN = imp.load_source(
    'DCGAN',
    os.path.join(os.path.dirname(__file__), 'DCGAN-CelebA.py'))


class Model(DCGAN.Model):
    # def generator(self, z):
    #     you can override generator to remove BatchNorm, it will still work in WGAN

    def build_losses(self, vecpos, vecneg):
        # the Wasserstein-GAN losses
        self.d_loss = tf.reduce_mean(vecneg - vecpos, name='d_loss')
        self.g_loss = -tf.reduce_mean(vecneg, name='g_loss')
        add_moving_summary(self.d_loss, self.g_loss)

    def _get_optimizer(self):
        lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
        return tf.train.RMSPropOptimizer(lr)


DCGAN.BATCH = 64
DCGAN.Model = Model


def get_config():
    return TrainConfig(
        model=Model(),
        # use the same data in the DCGAN example
        dataflow=DCGAN.get_data(args.data),
        callbacks=[ModelSaver()],
        session_config=get_default_sess_config(0.5),
        steps_per_epoch=300,
        max_epoch=200,
    )


class WGANTrainer(FeedfreeTrainerBase):
    def __init__(self, config):
        self._input_method = QueueInput(config.dataflow)
        super(WGANTrainer, self).__init__(config)

    def _setup(self):
        super(WGANTrainer, self)._setup()
        self.build_train_tower()

        # add clipping to D optimizer
        def clip(p):
            n = p.op.name
            logger.info("Clip {}".format(n))
            return tf.clip_by_value(p, -0.01, 0.01)

        opt_G = self.model.get_optimizer()
        opt_D = optimizer.VariableAssignmentOptimizer(opt_G, clip)

        self.d_min = opt_D.minimize(
            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
        self.g_min = opt_G.minimize(
            self.model.g_loss, var_list=self.model.g_vars, name='g_op')

    def run_step(self):
        for k in range(5):
            self.sess.run(self.d_min)
        ret = self.sess.run([self.g_min] + self.get_extra_fetches())
        return ret[1:]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load model')
    parser.add_argument('--sample', action='store_true', help='view generated examples')
    parser.add_argument('--data', help='a jpeg directory')
    args = parser.parse_args()
    if args.sample:
        DCGAN.sample(args.load)
    else:
        assert args.data
        logger.auto_set_dir()
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        WGANTrainer(config).train()
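The VariableAssignmentOptimizer wrapper above is what enforces weight clipping on the discriminator; assuming it assigns clip(v) back to each variable after every update, the effect is the standard WGAN clipping step. A rough raw-TensorFlow-1.x sketch of the same idea, for illustration only (the function and op names below are hypothetical, not part of the commit):

import tensorflow as tf


def clipped_critic_step(opt, d_loss, d_vars, clip_value=0.01):
    """Sketch: run one critic update, then clip every critic weight into [-clip_value, clip_value]."""
    update_op = opt.minimize(d_loss, var_list=d_vars)
    with tf.control_dependencies([update_op]):
        clip_ops = [tf.assign(v, tf.clip_by_value(v, -clip_value, clip_value)) for v in d_vars]
        return tf.group(*clip_ops, name='d_min_clipped')

run_step then performs five such critic updates for every generator update, matching the n_critic = 5 default of the WGAN paper.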
tensorpack/tfutils/sessinit.py

@@ -17,13 +17,10 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
           'ParamRestore', 'ChainInit', 'JustCurrentSession', 'get_model_loader']

# TODO they initialize_all at the beginning by default.

@six.add_metaclass(ABCMeta)
class SessionInit(object):
    """ Base class for utilities to initialize a session. """

    def init(self, sess):
        """
        Initialize a session

@@ -40,7 +37,6 @@ class SessionInit(object):

class JustCurrentSession(SessionInit):
    """ This is a no-op placeholder"""

    def _init(self, sess):
        pass

@@ -49,7 +45,6 @@ class NewSession(SessionInit):
    """
    Initialize global variables by their initializer.
    """

    def _init(self, sess):
        sess.run(tf.global_variables_initializer())

@@ -62,7 +57,7 @@ class SaverRestore(SessionInit):
    def __init__(self, model_path, prefix=None):
        """
        Args:
-            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
+            model_path (str): path to the model (model-xxxx) or a ``checkpoint`` file.
            prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
        """
        model_path = get_checkpoint_path(model_path)

@@ -150,7 +145,7 @@ class ChainInit(SessionInit):
    def __init__(self, sess_inits, new_session=True):
        """
        Args:
-            sess_inits (list): list of :class:`SessionInit` instances.
+            sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
            new_session (bool): add a ``NewSession()`` and the beginning, if
                not there.
        """
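For context around the two docstring fixes, a small sketch of how these initializers compose; the checkpoint path below is illustrative, not from the commit:

from tensorpack.tfutils.sessinit import ChainInit, JustCurrentSession, SaverRestore

# SaverRestore takes a path to a model (model-xxxx) or a `checkpoint` file;
# ChainInit runs a list of SessionInit instances in order.
init = ChainInit([
    SaverRestore('train_log/WGAN-CelebA/model-5000'),  # illustrative checkpoint path
    JustCurrentSession(),                              # no-op placeholder
])

The new WGAN example above uses SaverRestore directly via ``config.session_init = SaverRestore(args.load)``.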
tensorpack/utils/utils.py

@@ -45,7 +45,7 @@ def change_env(name, val):

def get_rng(obj=None):
    """
-    Get a good RNG.
+    Get a good RNG seeded with time, pid and the object.

    Args:
        obj: some object to use to generate random seed.
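The updated docstring says the RNG is seeded with the time, the pid and the object; one plausible shape for such a helper, shown only as a sketch (not necessarily the library's exact code):

import os
from datetime import datetime

import numpy as np


def get_rng_sketch(obj=None):
    """Return a NumPy RNG whose seed mixes the object's identity, the process id and the current time."""
    seed = (id(obj) + os.getpid() +
            int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
    return np.random.RandomState(seed)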
tensorpack/utils/viz.py

@@ -17,7 +17,7 @@ except ImportError:
    pass

-__all__ = ['pyplot2img', 'pyplot_viz', 'interactive_imshow',
+__all__ = ['pyplot2img', 'interactive_imshow',
           'stack_patches', 'gen_stack_patches',
           'dump_dataflow_images', 'intensity_to_rgb']

@@ -34,31 +34,13 @@ def pyplot2img(plt):
    return im

-def pyplot_viz(img, shape=None):
-    """ Use pyplot to visualize the image. e.g., when input is grayscale, the result
-    will automatically have a colormap.
-
-    Returns:
-        np.ndarray: an image.
-
-    Note:
-        this is quite slow. and the returned image will have a border
-    """
-    plt.clf()
-    plt.axes([0, 0, 1, 1])
-    plt.imshow(img)
-    ret = pyplot2img(plt)
-    if shape is not None:
-        ret = cv2.resize(ret, shape)
-    return ret

def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
    """
    Args:
        img (np.ndarray): an image (expect BGR) to show.
-        lclick_cb: a callback func(img, x, y) for left click event.
+        lclick_cb, rclick_cb: a callback ``func(img, x, y)`` for left/right click event.
        kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
-            specify a callback func(img) for keypress.
+            specify a callback ``func(img)`` for keypress.

        Some existing keypress event handler:

@@ -187,7 +169,7 @@ def stack_patches(
        nr_row(int), nr_col(int): rows and cols of the grid.
            ``nr_col * nr_row`` must be equal to ``len(patch_list)``.
        border(int): border length between images.
-            Defaults to ``0.1 * min(image_w, image_h)``.
+            Defaults to ``0.1 * min(patch_width, patch_height)``.
        pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
            This option allows stacking patches of different shapes together.
        bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
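Following the clarified docstring, a small usage sketch of interactive_imshow: the click callbacks receive (img, x, y), and keyword arguments named key_cb_<char> register per-key callbacks that receive the image. File names here are illustrative only:

import cv2
from tensorpack.utils.viz import interactive_imshow

img = cv2.imread('face.jpg')  # illustrative input image (BGR)


def on_left_click(img, x, y):
    # print the clicked coordinate and its pixel value
    print("left click at ({}, {}): {}".format(x, y, img[y, x]))


# pressing 's' saves the currently shown image to disk
interactive_imshow(img, lclick_cb=on_left_click,
                   key_cb_s=lambda im: cv2.imwrite('dump.png', im))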