Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8156810d
Commit
8156810d
authored
May 27, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
misc small fixes
parent
65449110
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
10 additions
and
14 deletions
+10
-14
docs/_static/build_toc_group.js
docs/_static/build_toc_group.js
+1
-1
examples/GAN/BEGAN.py
examples/GAN/BEGAN.py
+2
-6
examples/GAN/GAN.py
examples/GAN/GAN.py
+2
-2
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+1
-1
examples/GAN/README.md
examples/GAN/README.md
+1
-0
scripts/dump-model-params.py
scripts/dump-model-params.py
+3
-4
No files found.
docs/_static/build_toc_group.js
View file @
8156810d
...
...
@@ -17,7 +17,7 @@ $(function (){
if
(
fullname
.
startsWith
(
'
tensorpack.
'
))
fullname
=
fullname
.
substr
(
11
);
if
(
fullname
==
"
tensorpack.
dataflow.MultiProcessMapData
"
)
{
if
(
fullname
==
"
dataflow.MultiProcessMapData
"
)
{
groupName
=
"
parallel_map
"
;
}
...
...
examples/GAN/BEGAN.py
View file @
8156810d
...
...
@@ -11,7 +11,7 @@ from tensorpack.tfutils.summary import add_moving_summary
from
tensorpack.utils.gpu
import
get_num_gpu
import
DCGAN
from
GAN
import
GANModelDesc
,
GANTrainer
,
MultiGPUGANTrainer
from
GAN
import
GANModelDesc
,
GANTrainer
"""
Boundary Equilibrium GAN.
...
...
@@ -139,11 +139,7 @@ if __name__ == '__main__':
input
=
QueueInput
(
DCGAN
.
get_data
())
model
=
Model
()
nr_tower
=
max
(
get_num_gpu
(),
1
)
if
nr_tower
==
1
:
trainer
=
GANTrainer
(
input
,
model
)
else
:
trainer
=
MultiGPUGANTrainer
(
nr_tower
,
input
,
model
)
trainer
=
GANTrainer
(
input
,
model
,
num_gpu
=
nr_tower
)
trainer
.
train_with_defaults
(
callbacks
=
[
ModelSaver
(),
...
...
examples/GAN/GAN.py
View file @
8156810d
...
...
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK.
"""
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input_signature
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
inputs
())
with
TowerContext
(
''
,
is_training
=
True
):
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
...
...
@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer):
self
.
register_callback
(
cbs
)
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input_signature
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
inputs
())
with
TowerContext
(
''
,
is_training
=
True
),
\
argscope
(
BatchNorm
,
ema_update
=
'internal'
):
# should not hook the EMA updates to both train_op, it will hurt training speed.
...
...
examples/GAN/Improved-WGAN.py
View file @
8156810d
...
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
logger
.
auto_set_dir
()
SeparateGANTrainer
(
QueueInput
(
DCGAN
.
get_data
()),
M
,
g_period
=
6
)
.
train_with_defaults
(
M
,
g_period
=
5
)
.
train_with_defaults
(
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
max_epoch
=
200
,
...
...
examples/GAN/README.md
View file @
8156810d
...
...
@@ -63,6 +63,7 @@ Train a simple GAN on mnist, conditioned on the class labels.
## [WGAN.py](WGAN.py), [Improved-WGAN.py](Improved-WGAN.py), [BEGAN.py](BEGAN.py)
These variants are implemented by some small modifications on top of DCGAN.py.
BEGAN has the best visual quality among them.
Some BEGAN samples:

...
...
scripts/dump-model-params.py
View file @
8156810d
...
...
@@ -3,7 +3,6 @@
# File: dump-model-params.py
import
argparse
import
sys
import
numpy
as
np
import
os
import
six
...
...
@@ -34,9 +33,9 @@ def _import_external_ops(message):
pass
else
:
_validate_and_load_nccl_so
()
from
tensorflow.contrib.nccl.ops
import
gen_nccl_ops
from
tensorflow.contrib.nccl.ops
import
gen_nccl_ops
# noqa
else
:
from
tensorflow.python.ops
import
gen_nccl_ops
from
tensorflow.python.ops
import
gen_nccl_ops
# noqa
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment