Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
01486c39
Commit
01486c39
authored
May 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
MultiGPU GAN Trainer
parent
acf8e8f3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
5 deletions
+46
-5
examples/GAN/BEGAN.py
examples/GAN/BEGAN.py
+7
-2
examples/GAN/GAN.py
examples/GAN/GAN.py
+36
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-1
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+2
-0
No files found.
examples/GAN/BEGAN.py
View file @
01486c39
...
@@ -12,7 +12,7 @@ from tensorpack.utils.globvars import globalns as G
...
@@ -12,7 +12,7 @@ from tensorpack.utils.globvars import globalns as G
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
import
tensorflow
as
tf
from
GAN
import
GANModelDesc
,
GANTrainer
from
GAN
import
GANModelDesc
,
GANTrainer
,
MultiGPUGANTrainer
"""
"""
Boundary Equilibrium GAN.
Boundary Equilibrium GAN.
...
@@ -161,4 +161,9 @@ if __name__ == '__main__':
...
@@ -161,4 +161,9 @@ if __name__ == '__main__':
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
GANTrainer
(
config
)
.
train
()
nr_gpu
=
get_nr_gpu
()
config
.
nr_tower
=
max
(
get_nr_gpu
(),
1
)
if
config
.
nr_tower
==
1
:
GANTrainer
(
config
)
.
train
()
else
:
MultiGPUGANTrainer
(
config
)
.
train
()
examples/GAN/GAN.py
View file @
01486c39
...
@@ -6,7 +6,9 @@
...
@@ -6,7 +6,9 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
numpy
as
np
import
numpy
as
np
import
time
import
time
from
tensorpack
import
(
FeedfreeTrainerBase
,
QueueInput
,
ModelDesc
,
DataFlow
)
from
tensorpack
import
(
FeedfreeTrainerBase
,
QueueInput
,
ModelDesc
,
DataFlow
,
StagingInputWrapper
,
MultiGPUTrainerBase
,
LeastLoadedDeviceSetter
)
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.summary
import
add_moving_summary
...
@@ -17,7 +19,9 @@ class GANModelDesc(ModelDesc):
...
@@ -17,7 +19,9 @@ class GANModelDesc(ModelDesc):
and same with self.d_vars.
and same with self.d_vars.
"""
"""
self
.
g_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
g_scope
)
self
.
g_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
g_scope
)
assert
self
.
g_vars
self
.
d_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
d_scope
)
self
.
d_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
d_scope
)
assert
self
.
d_vars
def
build_losses
(
self
,
logits_real
,
logits_fake
):
def
build_losses
(
self
,
logits_real
,
logits_fake
):
"""D and G play two-player minimax game with value function V(G,D)
"""D and G play two-player minimax game with value function V(G,D)
...
@@ -56,7 +60,6 @@ class GANModelDesc(ModelDesc):
...
@@ -56,7 +60,6 @@ class GANModelDesc(ModelDesc):
class
GANTrainer
(
FeedfreeTrainerBase
):
class
GANTrainer
(
FeedfreeTrainerBase
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
# TODO design better
self
.
_input_source
=
QueueInput
(
config
.
dataflow
)
self
.
_input_source
=
QueueInput
(
config
.
dataflow
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
...
@@ -105,6 +108,37 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
...
@@ -105,6 +108,37 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
self
.
_cnt
+=
1
self
.
_cnt
+=
1
class
MultiGPUGANTrainer
(
MultiGPUTrainerBase
,
FeedfreeTrainerBase
):
def
__init__
(
self
,
config
):
super
(
MultiGPUGANTrainer
,
self
)
.
__init__
(
config
)
self
.
_nr_gpu
=
config
.
nr_tower
assert
self
.
_nr_gpu
>
1
self
.
_raw_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
self
.
config
.
tower
]
self
.
_input_source
=
StagingInputWrapper
(
QueueInput
(
config
.
dataflow
),
self
.
_raw_devices
)
def
_setup
(
self
):
super
(
MultiGPUGANTrainer
,
self
)
.
_setup
()
devices
=
[
LeastLoadedDeviceSetter
(
d
,
self
.
_raw_devices
)
for
d
in
self
.
_raw_devices
]
def
get_cost
():
self
.
build_train_tower
()
return
[
self
.
model
.
d_loss
,
self
.
model
.
g_loss
]
cost_list
=
MultiGPUTrainerBase
.
build_on_multi_tower
(
self
.
config
.
tower
,
get_cost
,
devices
)
# simply average the cost. might be faster to average the gradients
d_loss
=
tf
.
add_n
([
x
[
0
]
for
x
in
cost_list
])
*
(
1.0
/
self
.
_nr_gpu
)
g_loss
=
tf
.
add_n
([
x
[
1
]
for
x
in
cost_list
])
*
(
1.0
/
self
.
_nr_gpu
)
opt
=
self
.
model
.
get_optimizer
()
# run one d_min after one g_min
self
.
g_min
=
opt
.
minimize
(
g_loss
,
var_list
=
self
.
model
.
g_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'g_op'
)
with
tf
.
control_dependencies
([
self
.
g_min
]):
self
.
d_min
=
opt
.
minimize
(
d_loss
,
var_list
=
self
.
model
.
d_vars
,
colocate_gradients_with_ops
=
True
,
name
=
'd_op'
)
self
.
train_op
=
self
.
d_min
class
RandomZData
(
DataFlow
):
class
RandomZData
(
DataFlow
):
def
__init__
(
self
,
shape
):
def
__init__
(
self
,
shape
):
super
(
RandomZData
,
self
)
.
__init__
()
super
(
RandomZData
,
self
)
.
__init__
()
...
...
tensorpack/train/feedfree.py
View file @
01486c39
...
@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
...
@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" get the cost and gradient"""
""" get the cost and gradient"""
self
.
build_train_tower
()
self
.
build_train_tower
()
cost
=
self
.
model
.
get_cost
()
# assume single cost
cost
=
self
.
model
.
get_cost
()
# assume single cost
opt
=
self
.
config
.
optimizer
opt
=
self
.
config
.
optimizer
# TODO XXX
# GATE_NONE faster?
# GATE_NONE faster?
grads
=
opt
.
compute_gradients
(
grads
=
opt
.
compute_gradients
(
cost
,
cost
,
...
...
tensorpack/train/input_source.py
View file @
01486c39
...
@@ -393,6 +393,8 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -393,6 +393,8 @@ class StagingInputWrapper(FeedfreeInput):
self
.
_stage_ops
.
append
(
stage
.
put
(
inputs
))
self
.
_stage_ops
.
append
(
stage
.
put
(
inputs
))
self
.
_areas
.
append
(
stage
)
self
.
_areas
.
append
(
stage
)
outputs
=
stage
.
get
()
outputs
=
stage
.
get
()
if
isinstance
(
outputs
,
tf
.
Tensor
):
# when size=1, TF doesn't return a list
outputs
=
[
outputs
]
for
vin
,
vout
in
zip
(
inputs
,
outputs
):
for
vin
,
vout
in
zip
(
inputs
,
outputs
):
vout
.
set_shape
(
vin
.
get_shape
())
vout
.
set_shape
(
vin
.
get_shape
())
self
.
_unstage_ops
.
append
(
outputs
)
self
.
_unstage_ops
.
append
(
outputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment