Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
491e0cd9
Commit
491e0cd9
authored
Nov 25, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use globalns in DCGAN
parent
2a0e96e0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
12 deletions
+23
-12
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+12
-11
tensorpack/utils/globvars.py
tensorpack/utils/globvars.py
+11
-1
No files found.
examples/GAN/DCGAN-CelebA.py
View file @
491e0cd9
...
@@ -13,6 +13,7 @@ import cv2
...
@@ -13,6 +13,7 @@ import cv2
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
,
summary_moving_average
from
tensorpack.tfutils.summary
import
add_moving_summary
,
summary_moving_average
from
tensorpack.utils.globvars
import
globalns
as
CFG
,
use_global_argument
import
tensorpack.tfutils.symbolic_functions
as
symbf
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
GAN
import
GANTrainer
,
RandomZData
,
build_GAN_losses
from
GAN
import
GANTrainer
,
RandomZData
,
build_GAN_losses
...
@@ -27,12 +28,13 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
...
@@ -27,12 +28,13 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load model.tfmodel --sample
./DCGAN-CelebA.py --load model.tfmodel --sample
"""
"""
SHAPE
=
64
CFG
.
SHAPE
=
64
BATCH
=
128
CFG
.
BATCH
=
128
CFG
.
Z_DIM
=
100
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
def
_get_input_vars
(
self
):
return
[
InputVar
(
tf
.
float32
,
(
None
,
SHAPE
,
SHAPE
,
3
),
'input'
)
]
return
[
InputVar
(
tf
.
float32
,
(
None
,
CFG
.
SHAPE
,
CFG
.
SHAPE
,
3
),
'input'
)
]
def
generator
(
self
,
z
):
def
generator
(
self
,
z
):
""" return a image generated from z"""
""" return a image generated from z"""
...
@@ -66,8 +68,8 @@ class Model(ModelDesc):
...
@@ -66,8 +68,8 @@ class Model(ModelDesc):
image_pos
=
input_vars
[
0
]
image_pos
=
input_vars
[
0
]
image_pos
=
image_pos
/
128.0
-
1
image_pos
=
image_pos
/
128.0
-
1
z
=
tf
.
random_uniform
([
BATCH
,
100
],
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
random_uniform
([
CFG
.
BATCH
,
CFG
.
Z_DIM
],
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
placeholder_with_default
(
z
,
[
None
,
100
],
name
=
'z'
)
z
=
tf
.
placeholder_with_default
(
z
,
[
None
,
CFG
.
Z_DIM
],
name
=
'z'
)
with
argscope
([
Conv2D
,
Deconv2D
,
FullyConnected
],
with
argscope
([
Conv2D
,
Deconv2D
,
FullyConnected
],
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
...
@@ -85,12 +87,12 @@ class Model(ModelDesc):
...
@@ -85,12 +87,12 @@ class Model(ModelDesc):
self
.
d_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'discrim/'
)]
self
.
d_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'discrim/'
)]
def
get_data
():
def
get_data
():
datadir
=
args
.
data
datadir
=
CFG
.
data
imgs
=
glob
.
glob
(
datadir
+
'/*.jpg'
)
imgs
=
glob
.
glob
(
datadir
+
'/*.jpg'
)
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
augs
=
[
imgaug
.
CenterCrop
(
1
1
0
),
imgaug
.
Resize
(
64
)
]
augs
=
[
imgaug
.
CenterCrop
(
1
4
0
),
imgaug
.
Resize
(
64
)
]
ds
=
AugmentImageComponent
(
ds
,
augs
)
ds
=
AugmentImageComponent
(
ds
,
augs
)
ds
=
BatchData
(
ds
,
BATCH
)
ds
=
BatchData
(
ds
,
CFG
.
BATCH
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
return
ds
return
ds
...
@@ -149,8 +151,8 @@ if __name__ == '__main__':
...
@@ -149,8 +151,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--sample'
,
action
=
'store_true'
,
help
=
'run sampling'
)
parser
.
add_argument
(
'--sample'
,
action
=
'store_true'
,
help
=
'run sampling'
)
parser
.
add_argument
(
'--vec'
,
action
=
'store_true'
,
help
=
'run vec arithmetic demo'
)
parser
.
add_argument
(
'--vec'
,
action
=
'store_true'
,
help
=
'run vec arithmetic demo'
)
parser
.
add_argument
(
'--data'
,
help
=
'`image_align_celeba` directory of the celebA dataset'
)
parser
.
add_argument
(
'--data'
,
help
=
'`image_align_celeba` directory of the celebA dataset'
)
global
args
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
use_global_argument
(
args
)
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
if
args
.
sample
:
if
args
.
sample
:
...
@@ -162,5 +164,4 @@ if __name__ == '__main__':
...
@@ -162,5 +164,4 @@ if __name__ == '__main__':
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
GANTrainer
(
config
)
.
train
()
GANTrainer
(
config
,
g_vs_d
=
1
)
.
train
()
tensorpack/utils/globvars.py
View file @
491e0cd9
...
@@ -4,8 +4,9 @@
...
@@ -4,8 +4,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
six
import
six
import
argparse
__all__
=
[
'globalns'
]
__all__
=
[
'globalns'
,
'use_global_argument'
]
if
six
.
PY2
:
if
six
.
PY2
:
class
NS
:
pass
class
NS
:
pass
...
@@ -14,3 +15,12 @@ else:
...
@@ -14,3 +15,12 @@ else:
NS
=
types
.
SimpleNamespace
NS
=
types
.
SimpleNamespace
globalns
=
NS
()
globalns
=
NS
()
def
use_global_argument
(
args
):
"""
Add the content of argparse.Namespace to globalns
:param args: Argument
"""
assert
isinstance
(
args
,
argparse
.
Namespace
),
type
(
args
)
for
k
,
v
in
six
.
iteritems
(
vars
(
args
)):
setattr
(
globalns
,
k
,
v
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment