Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2d99afea
Commit
2d99afea
authored
Oct 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[WIP] Switch GANs to use Trainerv2
parent
17a73a4c
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
71 additions
and
54 deletions
+71
-54
examples/GAN/BEGAN.py
examples/GAN/BEGAN.py
+1
-0
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+1
-0
examples/GAN/CycleGAN.py
examples/GAN/CycleGAN.py
+1
-0
examples/GAN/DCGAN.py
examples/GAN/DCGAN.py
+4
-3
examples/GAN/DiscoGAN-CelebA.py
examples/GAN/DiscoGAN-CelebA.py
+1
-0
examples/GAN/GAN.py
examples/GAN/GAN.py
+30
-32
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+1
-0
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+1
-0
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+1
-0
examples/GAN/WGAN.py
examples/GAN/WGAN.py
+4
-3
tensorpack/train/base.py
tensorpack/train/base.py
+26
-16
No files found.
examples/GAN/BEGAN.py
View file @
2d99afea
...
...
@@ -6,6 +6,7 @@
import
os
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.gpu
import
get_nr_gpu
...
...
examples/GAN/ConditionalGAN-mnist.py
View file @
2d99afea
...
...
@@ -10,6 +10,7 @@ import sys
import
cv2
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
...
...
examples/GAN/CycleGAN.py
View file @
2d99afea
...
...
@@ -9,6 +9,7 @@ import glob
from
six.moves
import
map
,
zip
,
range
import
numpy
as
np
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
...
...
examples/GAN/DCGAN.py
View file @
2d99afea
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
os
,
sys
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
...
...
@@ -156,11 +157,11 @@ if __name__ == '__main__':
assert
args
.
data
logger
.
auto_set_dir
()
config
=
TrainConfig
(
model
=
Model
(),
dataflow
=
get_data
(
args
.
data
),
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
max_epoch
=
200
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
GANTrainer
(
config
)
.
train
()
GANTrainer
(
input
=
QueueInput
(
get_data
(
args
.
data
)),
model
=
Model
())
.
train_with_config
(
config
)
examples/GAN/DiscoGAN-CelebA.py
View file @
2d99afea
...
...
@@ -8,6 +8,7 @@ import argparse
from
six.moves
import
map
,
zip
import
numpy
as
np
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
...
...
examples/GAN/GAN.py
View file @
2d99afea
...
...
@@ -6,9 +6,9 @@
import
tensorflow
as
tf
import
numpy
as
np
import
time
from
tensorpack
import
(
Trainer
,
QueueInput
,
from
tensorpack
import
(
T
owerT
rainer
,
QueueInput
,
ModelDescBase
,
DataFlow
,
StagingInput
,
TowerContext
)
TowerContext
,
TowerFuncWrapper
)
from
tensorpack.graph_builder
import
DataParallelBuilder
,
LeastLoadedDeviceSetter
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.argtools
import
memoized
...
...
@@ -64,20 +64,15 @@ class GANModelDesc(ModelDescBase):
return
self
.
_get_optimizer
()
class
GANTrainer
(
Trainer
):
def
__init__
(
self
,
config
):
"""
GANTrainer expects a ModelDesc in config which sets the following attribute
after :meth:`_build_graph`: g_loss, d_loss, g_vars, d_vars.
"""
input
=
QueueInput
(
config
.
dataflow
)
model
=
config
.
model
class
GANTrainer
(
TowerTrainer
):
def
__init__
(
self
,
input
,
model
):
super
(
GANTrainer
,
self
)
.
__init__
()
assert
isinstance
(
model
,
GANModelDesc
),
model
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
config
.
callbacks
.
extend
(
cbs
)
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
())
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
input
)
tower_func
(
input
)
opt
=
model
.
get_optimizer
()
# by default, run one d_min after one g_min
...
...
@@ -86,29 +81,29 @@ class GANTrainer(Trainer):
with
tf
.
control_dependencies
([
g_min
]):
d_min
=
opt
.
minimize
(
model
.
d_loss
,
var_list
=
model
.
d_vars
,
name
=
'd_op'
)
self
.
train_op
=
d_min
self
.
set_tower_func
(
tower_func
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
for
cb
in
cbs
:
self
.
_register_callback
(
cb
)
class
SeparateGANTrainer
(
Trainer
):
class
SeparateGANTrainer
(
T
owerT
rainer
):
""" A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
def
__init__
(
self
,
config
,
d_period
=
1
,
g_period
=
1
):
def
__init__
(
self
,
input
,
model
,
d_period
=
1
,
g_period
=
1
):
"""
Args:
d_period(int): period of each d_opt run
g_period(int): period of each g_opt run
"""
super
(
SeparateGANTrainer
,
self
)
.
__init__
()
self
.
_d_period
=
int
(
d_period
)
self
.
_g_period
=
int
(
g_period
)
assert
min
(
d_period
,
g_period
)
==
1
input
=
QueueInput
(
config
.
dataflow
)
model
=
config
.
model
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
config
.
callbacks
.
extend
(
cbs
)
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_inputs_desc
()
)
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
input
)
tower_func
(
input
)
opt
=
model
.
get_optimizer
()
with
tf
.
name_scope
(
'optimize'
):
...
...
@@ -117,7 +112,9 @@ class SeparateGANTrainer(Trainer):
self
.
g_min
=
opt
.
minimize
(
model
.
g_loss
,
var_list
=
model
.
g_vars
,
name
=
'g_min'
)
super
(
SeparateGANTrainer
,
self
)
.
__init__
(
config
)
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
self
.
_register_callback
(
cb
)
def
run_step
(
self
):
if
self
.
global_step
%
(
self
.
_d_period
)
==
0
:
...
...
@@ -126,26 +123,25 @@ class SeparateGANTrainer(Trainer):
self
.
hooked_sess
.
run
(
self
.
g_min
)
class
MultiGPUGANTrainer
(
Trainer
):
class
MultiGPUGANTrainer
(
T
owerT
rainer
):
"""
A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
"""
def
__init__
(
self
,
config
):
nr_gpu
=
config
.
nr_tower
def
__init__
(
self
,
nr_gpu
,
input
,
model
):
super
(
MultiGPUGANTrainer
,
self
)
.
__init__
()
assert
nr_gpu
>
1
raw_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
config
.
tower
]
raw_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
range
(
nr_gpu
)
]
# setup input
input
=
StagingInput
(
QueueInput
(
config
.
dataflow
),
config
.
tower
)
model
=
config
.
model
input
=
StagingInput
(
input
,
list
(
range
(
nr_gpu
)))
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
config
.
callbacks
.
extend
(
cbs
)
def
get_cost
():
model
.
build_graph
(
input
)
model
.
build_graph
(
input
.
get_input_tensors
()
)
return
[
model
.
d_loss
,
model
.
g_loss
]
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_inputs_desc
())
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
cost_list
=
DataParallelBuilder
.
build_on_towers
(
config
.
tower
,
get_cost
,
devices
)
cost_list
=
DataParallelBuilder
.
build_on_towers
(
list
(
range
(
nr_gpu
)),
tower_func
,
devices
)
# simply average the cost. It might get faster to average the gradients
with
tf
.
name_scope
(
'optimize'
):
d_loss
=
tf
.
add_n
([
x
[
0
]
for
x
in
cost_list
])
*
(
1.0
/
nr_gpu
)
...
...
@@ -159,7 +155,9 @@ class MultiGPUGANTrainer(Trainer):
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
model
.
d_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
self
.
train_op
=
d_min
super
(
MultiGPUGANTrainer
,
self
)
.
__init__
(
config
)
self
.
set_tower_func
(
tower_func
)
for
cb
in
cbs
:
self
.
_register_callback
(
cb
)
class
RandomZData
(
DataFlow
):
...
...
examples/GAN/Image2Image.py
View file @
2d99afea
...
...
@@ -12,6 +12,7 @@ import os
import
sys
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
...
...
examples/GAN/Improved-WGAN.py
View file @
2d99afea
...
...
@@ -6,6 +6,7 @@
import
os
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.globvars
import
globalns
as
G
...
...
examples/GAN/InfoGAN-mnist.py
View file @
2d99afea
...
...
@@ -10,6 +10,7 @@ import os
import
sys
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils
import
viz
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
,
under_name_scope
...
...
examples/GAN/WGAN.py
View file @
2d99afea
...
...
@@ -6,6 +6,7 @@
import
os
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils.summary
import
add_moving_summary
...
...
@@ -76,8 +77,6 @@ if __name__ == '__main__':
assert
args
.
data
logger
.
auto_set_dir
()
config
=
TrainConfig
(
model
=
Model
(),
dataflow
=
DCGAN
.
get_data
(
args
.
data
),
callbacks
=
[
ModelSaver
(),
ClipCallback
()],
steps_per_epoch
=
500
,
max_epoch
=
200
,
...
...
@@ -85,4 +84,6 @@ if __name__ == '__main__':
)
# The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G
SeparateGANTrainer
(
config
,
d_period
=
3
)
.
train
()
SeparateGANTrainer
(
input
=
QueueInput
(
DCGAN
.
get_data
(
args
.
data
)),
model
=
Model
(),
d_period
=
3
)
.
train_with_config
(
config
)
tensorpack/train/base.py
View file @
2d99afea
...
...
@@ -218,6 +218,28 @@ class Trainer(object):
self
.
initialize
(
session_creator
,
session_init
)
self
.
main_loop
(
steps_per_epoch
,
starting_epoch
,
max_epoch
)
def
train_with_config
(
self
,
config
):
"""
An alias to simplify the use of `TrainConfig`.
It is equivalent to the following:
.. code-block:: python
self.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
"""
if
config
.
data
or
config
.
dataflow
or
config
.
model
:
logger
.
warn
(
"data/dataflow/model in TrainConfig will not be used "
"in `Trainer.train_with_config`"
)
logger
.
warn
(
"To build the graph from config, use `launch_train_with_config`!"
)
self
.
train
(
config
.
callbacks
,
config
.
monitors
,
config
.
session_creator
,
config
.
session_init
,
config
.
steps_per_epoch
,
config
.
starting_epoch
,
config
.
max_epoch
)
# create the old trainer when called with TrainConfig
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
TrainConfig
))
\
...
...
@@ -337,20 +359,6 @@ class SingleCostTrainer(TowerTrainer):
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
def
train
(
self
,
callbacks
,
monitors
,
session_creator
,
session_init
,
steps_per_epoch
,
starting_epoch
,
max_epoch
):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks
=
callbacks
+
self
.
_internal_callbacks
super
(
SingleCostTrainer
,
self
)
.
train
(
callbacks
,
monitors
,
session_creator
,
session_init
,
steps_per_epoch
,
starting_epoch
,
max_epoch
)
@
call_only_once
def
setup_graph
(
self
,
inputs_desc
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
...
...
@@ -375,8 +383,10 @@ class SingleCostTrainer(TowerTrainer):
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
self
.
_internal_callbacks
=
input_callbacks
+
train_callbacks
return
self
.
_internal_callbacks
internal_callbacks
=
input_callbacks
+
train_callbacks
for
cb
in
internal_callbacks
:
self
.
_register_callback
(
cb
)
return
internal_callbacks
@
abstractmethod
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment