Project: Shashank Suhas / seminar-breakout

Commit 4c82fb50, authored Mar 21, 2017 by Yuxin Wu
Parent: 0763b16d

    move WGANTrainer to GAN

Showing 3 changed files with 42 additions and 32 deletions (+42 / -32):

    examples/GAN/GAN.py             +30 / -0
    examples/GAN/WGAN-CelebA.py      +7 / -28
    tensorpack/tfutils/summary.py    +5 / -4
examples/GAN/GAN.py

@@ -71,6 +71,36 @@ class GANTrainer(FeedfreeTrainerBase):
         self.train_op = self.d_min
 
 
+class SplitGANTrainer(FeedfreeTrainerBase):
+    """ A new trainer which runs two optimization ops with a certain ratio. """
+    def __init__(self, config, d_interval=1):
+        """
+        Args:
+            d_interval: will run d_opt only after this many of g_opt.
+        """
+        self._input_method = QueueInput(config.dataflow)
+        self._d_interval = d_interval
+        super(SplitGANTrainer, self).__init__(config)
+
+    def _setup(self):
+        super(SplitGANTrainer, self)._setup()
+        self.build_train_tower()
+        opt = self.model.get_optimizer()
+        self.d_min = opt.minimize(
+            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
+        self.g_min = opt.minimize(
+            self.model.g_loss, var_list=self.model.g_vars, name='g_min')
+        self._cnt = 0
+
+    def run_step(self):
+        self._cnt += 1
+        if self._cnt % (self._d_interval) == 0:
+            self.hooked_sess.run(self.d_min)
+        else:
+            self.hooked_sess.run(self.g_min)
+
+
 class RandomZData(DataFlow):
     def __init__(self, shape):
         super(RandomZData, self).__init__()
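To make the d_interval semantics concrete, here is a standalone sketch (not part of the commit) of the schedule that run_step implements:

    # Pure-Python model of SplitGANTrainer.run_step's op selection.
    def schedule(d_interval, steps):
        cnt, ops = 0, []
        for _ in range(steps):
            cnt += 1
            ops.append('d_min' if cnt % d_interval == 0 else 'g_min')
        return ops

    # With d_interval=5, d_min runs on every 5th step and g_min on the rest:
    print(schedule(5, 10))
    # ['g_min', 'g_min', 'g_min', 'g_min', 'd_min',
    #  'g_min', 'g_min', 'g_min', 'g_min', 'd_min']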
examples/GAN/WGAN-CelebA.py

@@ -9,7 +9,7 @@ import argparse
 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 import tensorflow as tf
-from GAN import GANTrainer
+from GAN import SplitGANTrainer
 
 """
 Wasserstein-GAN.

@@ -61,36 +61,11 @@ def get_config():
         # use the same data in the DCGAN example
         dataflow=DCGAN.get_data(args.data),
         callbacks=[ModelSaver()],
-        steps_per_epoch=300,
+        steps_per_epoch=500,
         max_epoch=200,
     )
 
-
-class WGANTrainer(FeedfreeTrainerBase):
-    """ A new trainer which runs two optimization ops with 5:1 ratio.
-    This is to be consistent with the original code, but I found just
-    running them 1:1 (i.e. just using the existing GANTrainer) also works well.
-    """
-    def __init__(self, config):
-        self._input_method = QueueInput(config.dataflow)
-        super(WGANTrainer, self).__init__(config)
-
-    def _setup(self):
-        super(WGANTrainer, self)._setup()
-        self.build_train_tower()
-        opt = self.model.get_optimizer()
-        self.d_min = opt.minimize(
-            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
-        self.g_min = opt.minimize(
-            self.model.g_loss, var_list=self.model.g_vars, name='g_op')
-
-    def run_step(self):
-        for k in range(5):
-            self.hooked_sess.run(self.d_min)
-        self.hooked_sess.run(self.g_min)
-
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', help='load model')

@@ -105,4 +80,8 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    WGANTrainer(config).train()
+    """
+    This is to be consistent with the original code, but I found just
+    running them 1:1 (i.e. just using the existing GANTrainer) also works well.
+    """
+    SplitGANTrainer(config, d_interval=5).train()
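Note that, as written, the two schedules are not identical: WGANTrainer.run_step ran d_min five times and then g_min once on every step, whereas SplitGANTrainer with d_interval=5 runs d_min on every fifth call to run_step and g_min otherwise. Per the retained comment, a plain 1:1 schedule also trains well, so the pre-existing trainer remains an option (a sketch, not part of the commit, assuming the same config object as above):

    # Alternative mentioned in the comment above: the existing 1:1 trainer.
    # `config` comes from get_config() as in the hunk.
    from GAN import GANTrainer
    GANTrainer(config).train()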
tensorpack/tfutils/summary.py

@@ -130,9 +130,10 @@ def add_moving_summary(v, *args, **kwargs):
     averager = tf.train.ExponentialMovingAverage(
         decay, num_updates=get_global_step_var(), name='EMA')
     avg_maintain_op = averager.apply(v)
     for c in v:
         # TODO do this in the EMA callback?
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.summary.scalar(name + '-summary', averager.average(c))
     tf.add_to_collection(coll, avg_maintain_op)
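For reference, a minimal usage sketch of add_moving_summary (the losses are hypothetical placeholders, not from this commit; tensorpack creates the global-step variable internally when building the graph):

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary

    # Placeholder scalar losses; in the WGAN example these would be the
    # discriminator and generator losses defined by the model.
    d_loss = tf.reduce_mean(tf.random_normal([8]), name='d_loss')
    g_loss = tf.reduce_mean(tf.random_normal([8]), name='g_loss')

    # Each tensor gets a scalar summary named '<op_name>-summary' holding its
    # exponential moving average; the EMA maintain op goes into a collection
    # so a callback can run it during training.
    add_moving_summary(d_loss, g_loss)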