Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9f790afc
Commit
9f790afc
authored
Sep 18, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use tf.layers for Deconv implementation
parent
2b027291
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
41 deletions
+37
-41
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+2
-2
examples/GAN/DCGAN.py
examples/GAN/DCGAN.py
+4
-4
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+2
-2
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+26
-33
tensorpack/models/utils.py
tensorpack/models/utils.py
+3
-0
No files found.
examples/GAN/ConditionalGAN-mnist.py
View file @
9f790afc
...
@@ -42,10 +42,10 @@ class Model(GANModelDesc):
...
@@ -42,10 +42,10 @@ class Model(GANModelDesc):
y
=
tf
.
reshape
(
y
,
[
-
1
,
1
,
1
,
10
])
y
=
tf
.
reshape
(
y
,
[
-
1
,
1
,
1
,
10
])
l
=
tf
.
concat
([
l
,
tf
.
tile
(
y
,
[
1
,
7
,
7
,
1
])],
3
)
l
=
tf
.
concat
([
l
,
tf
.
tile
(
y
,
[
1
,
7
,
7
,
1
])],
3
)
l
=
Deconv2D
(
'deconv1'
,
l
,
[
14
,
14
,
64
*
2
]
,
5
,
2
,
nl
=
BNReLU
)
l
=
Deconv2D
(
'deconv1'
,
l
,
64
*
2
,
5
,
2
,
nl
=
BNReLU
)
l
=
tf
.
concat
([
l
,
tf
.
tile
(
y
,
[
1
,
14
,
14
,
1
])],
3
)
l
=
tf
.
concat
([
l
,
tf
.
tile
(
y
,
[
1
,
14
,
14
,
1
])],
3
)
l
=
Deconv2D
(
'deconv2'
,
l
,
[
28
,
28
,
1
]
,
5
,
2
,
nl
=
tf
.
identity
)
l
=
Deconv2D
(
'deconv2'
,
l
,
1
,
5
,
2
,
nl
=
tf
.
identity
)
l
=
tf
.
nn
.
tanh
(
l
,
name
=
'gen'
)
l
=
tf
.
nn
.
tanh
(
l
,
name
=
'gen'
)
return
l
return
l
...
...
examples/GAN/DCGAN.py
View file @
9f790afc
...
@@ -51,10 +51,10 @@ class Model(GANModelDesc):
...
@@ -51,10 +51,10 @@ class Model(GANModelDesc):
l
=
tf
.
reshape
(
l
,
[
-
1
,
4
,
4
,
nf
*
8
])
l
=
tf
.
reshape
(
l
,
[
-
1
,
4
,
4
,
nf
*
8
])
l
=
BNReLU
(
l
)
l
=
BNReLU
(
l
)
with
argscope
(
Deconv2D
,
nl
=
BNReLU
,
kernel_shape
=
4
,
stride
=
2
):
with
argscope
(
Deconv2D
,
nl
=
BNReLU
,
kernel_shape
=
4
,
stride
=
2
):
l
=
Deconv2D
(
'deconv1'
,
l
,
[
8
,
8
,
nf
*
4
]
)
l
=
Deconv2D
(
'deconv1'
,
l
,
nf
*
4
)
l
=
Deconv2D
(
'deconv2'
,
l
,
[
16
,
16
,
nf
*
2
]
)
l
=
Deconv2D
(
'deconv2'
,
l
,
nf
*
2
)
l
=
Deconv2D
(
'deconv3'
,
l
,
[
32
,
32
,
nf
]
)
l
=
Deconv2D
(
'deconv3'
,
l
,
nf
)
l
=
Deconv2D
(
'deconv4'
,
l
,
[
64
,
64
,
3
]
,
nl
=
tf
.
identity
)
l
=
Deconv2D
(
'deconv4'
,
l
,
3
,
nl
=
tf
.
identity
)
l
=
tf
.
tanh
(
l
,
name
=
'gen'
)
l
=
tf
.
tanh
(
l
,
name
=
'gen'
)
return
l
return
l
...
...
examples/GAN/InfoGAN-mnist.py
View file @
9f790afc
...
@@ -52,8 +52,8 @@ class Model(GANModelDesc):
...
@@ -52,8 +52,8 @@ class Model(GANModelDesc):
l
=
FullyConnected
(
'fc0'
,
z
,
1024
,
nl
=
BNReLU
)
l
=
FullyConnected
(
'fc0'
,
z
,
1024
,
nl
=
BNReLU
)
l
=
FullyConnected
(
'fc1'
,
l
,
128
*
7
*
7
,
nl
=
BNReLU
)
l
=
FullyConnected
(
'fc1'
,
l
,
128
*
7
*
7
,
nl
=
BNReLU
)
l
=
tf
.
reshape
(
l
,
[
-
1
,
7
,
7
,
128
])
l
=
tf
.
reshape
(
l
,
[
-
1
,
7
,
7
,
128
])
l
=
Deconv2D
(
'deconv1'
,
l
,
[
14
,
14
,
64
]
,
4
,
2
,
nl
=
BNReLU
)
l
=
Deconv2D
(
'deconv1'
,
l
,
64
,
4
,
2
,
nl
=
BNReLU
)
l
=
Deconv2D
(
'deconv2'
,
l
,
[
28
,
28
,
1
]
,
4
,
2
,
nl
=
tf
.
identity
)
l
=
Deconv2D
(
'deconv2'
,
l
,
1
,
4
,
2
,
nl
=
tf
.
identity
)
l
=
tf
.
sigmoid
(
l
,
name
=
'gen'
)
l
=
tf
.
sigmoid
(
l
,
name
=
'gen'
)
return
l
return
l
...
...
tensorpack/models/conv2d.py
View file @
9f790afc
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.common
import
layer_register
,
VariableHolder
from
.common
import
layer_register
,
VariableHolder
,
rename_get_variable
from
..utils.argtools
import
shape2d
,
shape4d
from
..utils.argtools
import
shape2d
,
shape4d
from
.
shape_utils
import
StaticDynamicAxis
from
.
.utils.develop
import
log_deprecated
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
...
@@ -80,7 +80,7 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -80,7 +80,7 @@ def Conv2D(x, out_channel, kernel_shape,
@
layer_register
(
log_shape
=
True
)
@
layer_register
(
log_shape
=
True
)
def
Deconv2D
(
x
,
out_
shape
,
kernel_shape
,
def
Deconv2D
(
x
,
out_
channel
,
kernel_shape
,
stride
,
padding
=
'SAME'
,
stride
,
padding
=
'SAME'
,
W_init
=
None
,
b_init
=
None
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
identity
,
use_bias
=
True
,
nl
=
tf
.
identity
,
use_bias
=
True
,
...
@@ -91,8 +91,7 @@ def Deconv2D(x, out_shape, kernel_shape,
...
@@ -91,8 +91,7 @@ def Deconv2D(x, out_shape, kernel_shape,
Args:
Args:
x (tf.Tensor): a tensor of shape NHWC.
x (tf.Tensor): a tensor of shape NHWC.
Must have known number of channels, but can have other unknown dimensions.
Must have known number of channels, but can have other unknown dimensions.
out_shape: (h, w, channel) tuple, or just a integer channel,
out_channel: the output number of channel.
then (h, w) will be calculated by input_shape * stride
kernel_shape: (h, w) tuple or a int.
kernel_shape: (h, w) tuple or a int.
stride: (h, w) tuple or a int.
stride: (h, w) tuple or a int.
padding (str): 'valid' or 'same'. Case insensitive.
padding (str): 'valid' or 'same'. Case insensitive.
...
@@ -113,47 +112,41 @@ def Deconv2D(x, out_shape, kernel_shape,
...
@@ -113,47 +112,41 @@ def Deconv2D(x, out_shape, kernel_shape,
channel_axis
=
3
if
data_format
==
'NHWC'
else
1
channel_axis
=
3
if
data_format
==
'NHWC'
else
1
in_channel
=
in_shape
[
channel_axis
]
in_channel
=
in_shape
[
channel_axis
]
assert
in_channel
is
not
None
,
"[Deconv2D] Input cannot have unknown channel!"
assert
in_channel
is
not
None
,
"[Deconv2D] Input cannot have unknown channel!"
kernel_shape
=
shape2d
(
kernel_shape
)
stride2d
=
shape2d
(
stride
)
stride4d
=
shape4d
(
stride
,
data_format
=
data_format
)
padding
=
padding
.
upper
()
in_shape_dyn
=
tf
.
shape
(
x
)
out_shape
=
out_channel
if
isinstance
(
out_shape
,
int
):
if
isinstance
(
out_shape
,
int
):
out_channel
=
out_shape
out_channel
=
out_shape
if
data_format
==
'NHWC'
:
shp3_0
=
StaticDynamicAxis
(
in_shape
[
1
],
in_shape_dyn
[
1
])
.
apply
(
lambda
x
:
stride2d
[
0
]
*
x
)
shp3_1
=
StaticDynamicAxis
(
in_shape
[
2
],
in_shape_dyn
[
2
])
.
apply
(
lambda
x
:
stride2d
[
1
]
*
x
)
shp3_dyn
=
[
shp3_0
.
dynamic
,
shp3_1
.
dynamic
,
out_channel
]
shp3_static
=
[
shp3_0
.
static
,
shp3_1
.
static
,
out_channel
]
else
:
shp3_0
=
StaticDynamicAxis
(
in_shape
[
2
],
in_shape_dyn
[
2
])
.
apply
(
lambda
x
:
stride2d
[
0
]
*
x
)
shp3_1
=
StaticDynamicAxis
(
in_shape
[
3
],
in_shape_dyn
[
3
])
.
apply
(
lambda
x
:
stride2d
[
1
]
*
x
)
shp3_dyn
=
[
out_channel
,
shp3_0
.
dynamic
,
shp3_1
.
dynamic
]
shp3_static
=
[
out_channel
,
shp3_0
.
static
,
shp3_1
.
static
]
else
:
else
:
log_deprecated
(
"Deconv2D(out_shape=[...])"
,
"Use an integer 'out_channel' instead!"
,
"2017-11-18"
)
for
k
in
out_shape
:
for
k
in
out_shape
:
if
not
isinstance
(
k
,
int
):
if
not
isinstance
(
k
,
int
):
raise
ValueError
(
"[Deconv2D] out_shape {} is invalid!"
.
format
(
k
))
raise
ValueError
(
"[Deconv2D] out_shape {} is invalid!"
.
format
(
k
))
out_channel
=
out_shape
[
channel_axis
-
1
]
# out_shape doesn't have batch
out_channel
=
out_shape
[
channel_axis
-
1
]
# out_shape doesn't have batch
shp3_static
=
shp3_dyn
=
out_shape
filter_shape
=
kernel_shape
+
[
out_channel
,
in_channel
]
if
W_init
is
None
:
if
W_init
is
None
:
W_init
=
tf
.
contrib
.
layers
.
xavier_initializer_conv2d
()
W_init
=
tf
.
contrib
.
layers
.
xavier_initializer_conv2d
()
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
b_init
=
tf
.
constant_initializer
()
W
=
tf
.
get_variable
(
'W'
,
filter_shape
,
initializer
=
W_init
)
if
use_bias
:
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
b_init
)
out_shape_dyn
=
tf
.
stack
([
tf
.
shape
(
x
)[
0
]]
+
shp3_dyn
)
conv
=
tf
.
nn
.
conv2d_transpose
(
x
,
W
,
out_shape_dyn
,
stride4d
,
padding
=
padding
,
data_format
=
data_format
)
conv
.
set_shape
(
tf
.
TensorShape
([
None
]
+
shp3_static
))
ret
=
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
,
data_format
=
data_format
)
if
use_bias
else
conv
,
name
=
'output'
)
ret
.
variables
=
VariableHolder
(
W
=
W
)
with
rename_get_variable
({
'kernel'
:
'W'
,
'bias'
:
'b'
}):
layer
=
tf
.
layers
.
Deconv2D
(
out_channel
,
kernel_shape
,
strides
=
stride
,
padding
=
padding
,
data_format
=
'channels_last'
if
data_format
==
'NHWC'
else
'channels_first'
,
activation
=
lambda
x
:
nl
(
x
,
name
=
'output'
),
use_bias
=
use_bias
,
kernel_initializer
=
W_init
,
bias_initializer
=
b_init
,
trainable
=
True
)
ret
=
layer
.
apply
(
x
,
scope
=
tf
.
get_variable_scope
())
# Check that we only supports out_shape = in_shape * stride
out_shape3
=
ret
.
get_shape
()
.
as_list
()[
1
:]
if
not
isinstance
(
out_shape
,
int
):
assert
list
(
out_shape
)
==
out_shape3
,
"{} != {}"
.
format
(
out_shape
,
out_shape3
)
ret
.
variables
=
VariableHolder
(
W
=
layer
.
kernel
)
if
use_bias
:
if
use_bias
:
ret
.
variables
.
b
=
b
ret
.
variables
.
b
=
layer
.
bias
return
ret
return
ret
tensorpack/models/utils.py
View file @
9f790afc
...
@@ -66,5 +66,8 @@ def monkeypatch_tf_layers():
...
@@ -66,5 +66,8 @@ def monkeypatch_tf_layers():
from
tensorflow.python.layers.normalization
import
BatchNormalization
from
tensorflow.python.layers.normalization
import
BatchNormalization
tf
.
layers
.
BatchNormalization
=
BatchNormalization
tf
.
layers
.
BatchNormalization
=
BatchNormalization
from
tensorflow.python.layers.convolutional
import
Deconv2D
tf
.
layers
.
Deconv2D
=
Deconv2D
monkeypatch_tf_layers
()
monkeypatch_tf_layers
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment