Shashank Suhas / seminar-breakout / Commits

Commit 5bd3c395, authored Oct 25, 2017 by Yuxin Wu
Parent: 2d99afea

    [WIP] GAN trainers with new API

Showing 2 changed files with 13 additions and 9 deletions (+13 -9):
    examples/GAN/BEGAN.py   +7 -6
    examples/GAN/GAN.py     +6 -3
examples/GAN/BEGAN.py (view file @ 5bd3c395)

@@ -146,8 +146,6 @@ if __name__ == '__main__':
     logger.auto_set_dir()
 
     config = TrainConfig(
-        model=Model(),
-        dataflow=DCGAN.get_data(args.data),
         callbacks=[
             ModelSaver(),
             StatMonitorParamSetter(
@@ -156,9 +154,12 @@ if __name__ == '__main__':
         steps_per_epoch=500,
         max_epoch=400,
         session_init=SaverRestore(args.load) if args.load else None,
-        nr_tower=max(get_nr_gpu(), 1)
     )
-    if config.nr_tower == 1:
-        GANTrainer(config).train()
+    input = QueueInput(DCGAN.get_data(args.data))
+    model = Model()
+    nr_tower = max(get_nr_gpu(), 1)
+    if nr_tower == 1:
+        trainer = GANTrainer(input, model)
     else:
-        MultiGPUGANTrainer(config).train()
+        trainer = MultiGPUGANTrainer(nr_tower, input, model)
+    trainer.train_with_config(config)
...
examples/GAN/GAN.py (view file @ 5bd3c395)

@@ -136,12 +136,15 @@ class MultiGPUGANTrainer(TowerTrainer):
         input = StagingInput(input, list(range(nr_gpu)))
         cbs = input.setup(model.get_inputs_desc())
-        def get_cost():
-            model.build_graph(input.get_input_tensors())
+        def get_cost(*inputs):
+            model.build_graph(inputs)
             return [model.d_loss, model.g_loss]
+
         tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         cost_list = DataParallelBuilder.build_on_towers(
-            list(range(nr_gpu)), tower_func, devices)
+            list(range(nr_gpu)),
+            lambda: tower_func(*input.get_input_tensors()),
+            devices)
         # simply average the cost. It might get faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
...
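The GAN.py hunk is the matching trainer-side half: the tower function now receives its input tensors as arguments (get_cost(*inputs)) instead of calling input.get_input_tensors() itself, and build_on_towers is handed a thunk that fetches the tensors and forwards them through the wrapped tower function. A minimal, framework-free sketch of that calling convention; every name below (ToyWrapper, fetch_inputs, build_on_towers as a plain function) is a placeholder for the real tensorpack object, not part of the commit:

    # Toy sketch of the calling convention the hunk adopts; illustrative names only.
    def get_cost(*inputs):
        # Under the new convention the tower function is handed the input tensors
        # directly, rather than pulling them from the InputSource itself.
        return sum(inputs)  # stand-in for model.build_graph(inputs) + [d_loss, g_loss]

    class ToyWrapper:
        # Stand-in for TowerFuncWrapper: records the function and forwards calls.
        def __init__(self, func):
            self.func = func
        def __call__(self, *args):
            return self.func(*args)

    def build_on_towers(tower_ids, thunk):
        # Stand-in for DataParallelBuilder.build_on_towers: invoke the thunk once
        # per tower; in tensorpack each call runs under that tower's context.
        return [thunk() for _ in tower_ids]

    def fetch_inputs():
        # Stand-in for input.get_input_tensors().
        return (1.0, 2.0)

    tower_func = ToyWrapper(get_cost)
    cost_list = build_on_towers(range(2), lambda: tower_func(*fetch_inputs()))
    print(cost_list)  # [3.0, 3.0]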