seminar-breakout · Shashank Suhas · Commits · ea60a630

Commit ea60a630, authored Feb 05, 2018 by Yuxin Wu

    Translate data_format and activation to tflayers (#627)

Parent: 0e5299bb

Showing 8 changed files with 67 additions and 53 deletions (+67 −53)
tensorpack/models/batch_norm.py    +8   -9
tensorpack/models/conv2d.py        +10  -19
tensorpack/models/fc.py            +3   -3
tensorpack/models/layer_norm.py    +5   -2
tensorpack/models/pool.py          +9   -11
tensorpack/models/registry.py      +14  -0
tensorpack/tfutils/argscope.py     +5   -7
tensorpack/utils/argtools.py       +13  -2
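
In short: layer signatures move from tensorpack's own argument names (nl=..., data_format='NHWC') to the tf.layers names (activation=..., data_format='channels_last'), and old call sites keep working because layer_register remaps the legacy names (see registry.py below). A minimal before/after sketch of a call site, assuming TF 1.x and tensorpack's registered Conv2D; the scope name 'conv0' is illustrative:

    # before this commit: tensorpack-style argument names
    l = Conv2D('conv0', x, 32, 3, nl=tf.nn.relu, data_format='NHWC')

    # after this commit: tf.layers-style names; the old spelling above still
    # works, because map_tfargs in registry.py translates nl -> activation
    # and 'NHWC' -> 'channels_last' before the layer function runs
    l = Conv2D('conv0', x, 32, 3, activation=tf.nn.relu, data_format='channels_last')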
tensorpack/models/batch_norm.py

@@ -8,6 +8,7 @@ from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
 from ..utils import logger
+from ..utils.argtools import get_data_format
 from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
@@ -67,7 +68,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
+              gamma_init=tf.constant_initializer(1.0),
+              data_format='channels_last',
               internal_update=False):
     """
     Batch Normalization layer, as described in the paper:
@@ -109,6 +111,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
            don't want to fine tune the EMA. EMA will not be updated in
            this case.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -181,7 +184,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 @layer_register()
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None, data_format='NHWC'):
+                use_scale=True, use_bias=True, gamma_init=None,
+                data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
@@ -210,18 +214,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     ndims = len(shape)
     assert ndims in [2, 4]
     if ndims == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+        data_format = 'channels_last'    # error using NCHW? (see #190)
         x = tf.reshape(x, [-1, 1, 1, shape[1]])
-    if data_format == 'NCHW':
-        n_out = shape[1]
-    else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"
     ctx = get_current_tower_context()
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
-        axis=1 if data_format == 'NCHW' else 3,
+        axis=1 if data_format == 'channels_first' else 3,
         momentum=decay, epsilon=epsilon,
         center=use_bias, scale=use_scale,
         renorm=True,
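
Two things happen above: BatchNorm converts the incoming tf-style string back to 'NHWC'/'NCHW' (tfmode=False) because its fused internals still use the op-level names, and BatchRenorm drops its manual n_out bookkeeping since tf.layers.BatchNormalization only needs a channel axis. A pure-Python sketch of that axis selection (no TensorFlow required):

    def channel_axis_4d(data_format):
        # mirrors the axis= argument passed to tf.layers.BatchNormalization above
        return 1 if data_format == 'channels_first' else 3

    assert channel_axis_4d('channels_first') == 1
    assert channel_axis_4d('channels_last') == 3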
tensorpack/models/conv2d.py

@@ -6,7 +6,7 @@
 import tensorflow as tf

 from .common import layer_register, VariableHolder, rename_get_variable
 from ..tfutils.common import get_tf_version_number
-from ..utils.argtools import shape2d, shape4d
+from ..utils.argtools import shape2d, shape4d, get_data_format

 __all__ = ['Conv2D', 'Deconv2D']
@@ -15,8 +15,8 @@ __all__ = ['Conv2D', 'Deconv2D']
 def Conv2D(x, out_channel, kernel_shape,
            padding='SAME', stride=1,
            W_init=None, b_init=None,
-           nl=tf.identity, split=1, use_bias=True,
-           data_format='NHWC', dilation_rate=1):
+           activation=tf.identity, split=1, use_bias=True,
+           data_format='channels_last', dilation_rate=1):
     """
     2D convolution on 4D inputs.
@@ -30,9 +30,7 @@ def Conv2D(x, out_channel, kernel_shape,
         split (int): Split channels as used in Alexnet. Defaults to 1 (no split).
         W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
-        data_format (str): 'NHWC' or 'NCHW'.
         dilation_rate: (h, w) tuple or a int.

     Returns:
@@ -43,6 +41,7 @@ def Conv2D(x, out_channel, kernel_shape,
     * ``W``: weights
     * ``b``: bias
     """
+    data_format = get_data_format(data_format, tfmode=False)
     in_shape = x.get_shape().as_list()
     channel_axis = 3 if data_format == 'NHWC' else 1
     in_channel = in_shape[channel_axis]
@@ -79,7 +78,7 @@ def Conv2D(x, out_channel, kernel_shape,
                    for i, k in zip(inputs, kernels)]
         conv = tf.concat(outputs, channel_axis)

-    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+    ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
     ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
@@ -90,8 +89,8 @@ def Conv2D(x, out_channel, kernel_shape,
 def Deconv2D(x, out_channel, kernel_shape,
              stride, padding='SAME',
              W_init=None, b_init=None,
-             nl=tf.identity, use_bias=True,
-             data_format='NHWC'):
+             activation=tf.identity, use_bias=True,
+             data_format='channels_last'):
     """
     2D deconvolution on 4D inputs.
@@ -104,7 +103,6 @@ def Deconv2D(x, out_channel, kernel_shape,
         padding (str): 'valid' or 'same'. Case insensitive.
         W_init: initializer for W. Defaults to `tf.variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.

     Returns:
@@ -115,13 +113,6 @@ def Deconv2D(x, out_channel, kernel_shape,
     * ``W``: weights
     * ``b``: bias
     """
-    in_shape = x.get_shape().as_list()
-    channel_axis = 3 if data_format == 'NHWC' else 1
-    in_channel = in_shape[channel_axis]
-    assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
-    assert isinstance(out_channel, int), out_channel
-
     if W_init is None:
         W_init = tf.variance_scaling_initializer(scale=2.0)
     if b_init is None:
@@ -131,8 +122,8 @@ def Deconv2D(x, out_channel, kernel_shape,
         layer = tf.layers.Conv2DTranspose(
             out_channel, kernel_shape,
             strides=stride, padding=padding,
-            data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
-            activation=lambda x: nl(x, name='output'),
+            data_format=data_format,
+            activation=activation,
             use_bias=use_bias,
             kernel_initializer=W_init,
             bias_initializer=b_init,
@@ -142,4 +133,4 @@ def Deconv2D(x, out_channel, kernel_shape,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
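
A behavioral detail of the Deconv2D rewrite: the 'output' name used to be attached by wrapping nl in a lambda; now the activation runs unwrapped and the stable name comes from a trailing tf.identity. A sketch of the new tail, assuming a TF 1.x graph (the shapes are illustrative):

    import tensorflow as tf  # assumes TF 1.x graph mode

    x = tf.placeholder(tf.float32, [None, 8, 8, 3])
    ret = tf.layers.conv2d_transpose(x, 16, 3, strides=2, padding='same',
                                     activation=tf.nn.relu)
    # the commit appends an identity, so the tensor name no longer depends
    # on what the activation happens to be
    out = tf.identity(ret, name='output')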
tensorpack/models/fc.py

@@ -14,7 +14,7 @@ __all__ = ['FullyConnected']
 @layer_register(log_shape=True)
 def FullyConnected(x, out_dim,
                    W_init=None, b_init=None,
-                   nl=tf.identity, use_bias=True):
+                   activation=tf.identity, use_bias=True):
     """
     Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor.
     It is an equivalent of `tf.layers.dense` except for naming conventions.
@@ -44,7 +44,7 @@ def FullyConnected(x, out_dim,
     with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
         layer = tf.layers.Dense(
-            out_dim, activation=lambda x: nl(x, name='output'), use_bias=use_bias,
+            out_dim, activation=activation, use_bias=use_bias,
             kernel_initializer=W_init, bias_initializer=b_init,
             trainable=True)
         ret = layer.apply(x, scope=tf.get_variable_scope())
@@ -52,4 +52,4 @@ def FullyConnected(x, out_dim,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
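
The Dense change follows the same pattern as conv2d.py: activation passes through unmodified and the 'output' name is attached afterwards. One practical consequence is that the activation no longer needs to accept a name keyword, so any plain callable works. A sketch, assuming TF 1.x:

    import tensorflow as tf  # assumes TF 1.x
    from functools import partial

    x = tf.placeholder(tf.float32, [None, 128])
    # previously the activation had to tolerate name='output' (hence the lambda);
    # now naming is decoupled, so e.g. a partial is a valid activation
    act = partial(tf.nn.leaky_relu, alpha=0.1)
    y = tf.identity(tf.layers.dense(x, 64, activation=act), name='output')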
tensorpack/models/layer_norm.py

@@ -5,6 +5,7 @@
 import tensorflow as tf

 from .common import layer_register, VariableHolder
+from ..utils.argtools import get_data_format

 __all__ = ['LayerNorm', 'InstanceNorm']
@@ -13,7 +14,7 @@ __all__ = ['LayerNorm', 'InstanceNorm']
 def LayerNorm(
         x, epsilon=1e-5,
         use_bias=True, use_scale=True,
-        gamma_init=None, data_format='NHWC'):
+        gamma_init=None, data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
@@ -23,6 +24,7 @@ def LayerNorm(
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -62,7 +64,7 @@ def LayerNorm(
 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
+def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
@@ -73,6 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         epsilon (float): avoid divide-by-zero
         use_affine (bool): whether to apply learnable affine transformation
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
tensorpack/models/pool.py

@@ -7,7 +7,7 @@ import numpy as np
 from .shape_utils import StaticDynamicShape
 from .common import layer_register
-from ..utils.argtools import shape2d
+from ..utils.argtools import shape2d, get_data_format
 from ._test import TestModel
@@ -16,7 +16,7 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
 @layer_register(log_shape=True)
-def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def MaxPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Max Pooling on 4D tensors.
@@ -31,13 +31,12 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.max_pooling2d(x, shape, stride, padding,
-                                  'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.max_pooling2d(x, shape, stride, padding, data_format=data_format)
     return tf.identity(ret, name='output')

 @layer_register(log_shape=True)
-def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def AvgPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Average Pooling on 4D tensors.
@@ -52,13 +51,12 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.average_pooling2d(x, shape, stride, padding,
-                                      'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.average_pooling2d(x, shape, stride, padding, data_format=data_format)
    return tf.identity(ret, name='output')

 @layer_register(log_shape=True)
-def GlobalAvgPooling(x, data_format='NHWC'):
+def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
     <http://arxiv.org/abs/1312.4400>`_.
@@ -69,8 +67,7 @@ def GlobalAvgPooling(x, data_format='NHWC'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
-    assert data_format in ['NHWC', 'NCHW']
-    axis = [1, 2] if data_format == 'NHWC' else [2, 3]
+    axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
@@ -90,7 +87,7 @@ def UnPooling2x2ZeroFilled(x):
 @layer_register(log_shape=True)
-def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
+def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
     """
     Unpool the input with a fixed matrix to perform kronecker product with.
@@ -103,6 +100,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
     Returns:
         tf.Tensor: a 4D image tensor.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = shape2d(shape)
     output_shape = StaticDynamicShape(x)
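
The pooling layers can now pass data_format straight into tf.layers, since it already arrives in tf dialect, and GlobalAvgPooling branches on the tf-style string directly. A small numpy check of the axis arithmetic, which is what the docstring's "NC tensor" promise boils down to:

    import numpy as np

    def global_avg_pool_axes(data_format):
        # mirrors the axis list in GlobalAvgPooling after this commit
        return [1, 2] if data_format == 'channels_last' else [2, 3]

    x = np.random.rand(4, 8, 8, 16)  # a channels_last (NHWC) batch
    pooled = x.mean(axis=tuple(global_avg_pool_axes('channels_last')))
    assert pooled.shape == (4, 16)   # N x C, as documented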
tensorpack/models/registry.py

@@ -11,6 +11,7 @@ import copy
 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.model_utils import get_shape_str
+from ..utils.argtools import get_data_format
 from ..utils import logger

 # make sure each layer is only logged once
@@ -20,6 +21,18 @@ _LAYER_REGISTRY = {}
 __all__ = ['layer_register']


+def map_tfargs(kwargs):
+    df = kwargs.pop('data_format', None)
+    if df is not None:
+        df = get_data_format(df, tfmode=True)
+        kwargs['data_format'] = df
+
+    old_nl = kwargs.pop('nl', None)
+    if old_nl is not None:
+        kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
+    return kwargs
+
+
 def _register(name, func):
     if name in _LAYER_REGISTRY:
         raise ValueError("Layer named {} is already registered!".format(name))
@@ -113,6 +126,7 @@ def layer_register(
             if k in actual_args:
                 del actual_args[k]

+        actual_args = map_tfargs(actual_args)
         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
                 # this name is only used to surpress logging, doesn't hurt to do some heuristics
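
map_tfargs is the backward-compatibility shim for the whole commit: every registered layer's kwargs pass through it, so legacy nl= and 'NHWC'/'NCHW' arguments keep working. A self-contained restatement that can be exercised without TensorFlow (get_data_format is inlined here from the argtools.py hunk below):

    def get_data_format(data_format, tfmode=True):
        # inlined from tensorpack/utils/argtools.py as added in this commit
        if tfmode:
            dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
        else:
            dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
        ret = dic.get(data_format, data_format)
        if ret not in dic.values():
            raise ValueError("Unknown data_format: {}".format(data_format))
        return ret

    def map_tfargs(kwargs):
        df = kwargs.pop('data_format', None)
        if df is not None:
            kwargs['data_format'] = get_data_format(df, tfmode=True)
        old_nl = kwargs.pop('nl', None)
        if old_nl is not None:
            kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
        return kwargs

    # legacy kwargs survive with their meaning intact:
    args = map_tfargs({'data_format': 'NHWC', 'nl': lambda x, name=None: x + 1})
    assert args['data_format'] == 'channels_last'
    assert args['activation'](41) == 42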
tensorpack/tfutils/argscope.py

@@ -4,9 +4,7 @@
 from contextlib import contextmanager
 from collections import defaultdict
-import inspect
 import copy
-import six

 __all__ = ['argscope', 'get_arg_scope']
@@ -35,14 +33,14 @@ def argscope(layers, **kwargs):
     if not isinstance(layers, list):
         layers = [layers]

-    def _check_args_exist(l):
-        args = inspect.getargspec(l).args
-        for k, v in six.iteritems(kwargs):
-            assert k in args, "No argument {} in {}".format(k, l.__name__)
+    # def _check_args_exist(l):
+    #     args = inspect.getargspec(l).args
+    #     for k, v in six.iteritems(kwargs):
+    #         assert k in args, "No argument {} in {}".format(k, l.__name__)

     for l in layers:
         assert hasattr(l, 'symbolic_function'), "{} is not a registered layer".format(l.__name__)
-        _check_args_exist(l.symbolic_function)
+        # _check_args_exist(l.symbolic_function)

     new_scope = copy.copy(get_arg_scope())
     for l in layers:
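
The signature check is commented out rather than fixed; a plausible reason (my inference, the commit doesn't say) is that getargspec-based validation would now reject the legacy spellings that map_tfargs still accepts at runtime. A toy illustration, using the non-deprecated inspect API:

    import inspect

    def layer(x, activation=None, data_format='channels_last'):
        return x

    args = inspect.getfullargspec(layer).args
    # argscope(..., nl=...) is still legal at runtime (remapped by map_tfargs),
    # but the old assert would have rejected it here:
    assert 'nl' not in args and 'activation' in args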
tensorpack/utils/argtools.py

@@ -111,7 +111,18 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))


-def shape4d(a, data_format='NHWC'):
+def get_data_format(data_format, tfmode=True):
+    if tfmode:
+        dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
+    else:
+        dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
+    ret = dic.get(data_format, data_format)
+    if ret not in dic.values():
+        raise ValueError("Unknown data_format: {}".format(data_format))
+    return ret
+
+
+def shape4d(a, data_format='channels_last'):
     """
     Ensuer a 4D shape, to use with 4D symbolic functions.
@@ -123,7 +134,7 @@ def shape4d(a, data_format='NHWC'):
     or ``[1, 1, a, a]`` depending on data_format.
     """
     s2d = shape2d(a)
-    if data_format == 'NHWC':
+    if get_data_format(data_format) == 'channels_last':
         return [1] + s2d + [1]
     else:
         return [1, 1] + s2d
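
get_data_format is deliberately idempotent: strings already in the target dialect pass through unchanged, so it is safe to call at every boundary (the registry converts to tf dialect, the layers convert back where fused ops need it). Exercising the function as defined above, assuming tensorpack at this revision is importable:

    from tensorpack.utils.argtools import get_data_format  # added in this commit

    assert get_data_format('NHWC') == 'channels_last'            # tfmode=True is the default
    assert get_data_format('channels_last') == 'channels_last'   # already tf-style: pass-through
    assert get_data_format('channels_first', tfmode=False) == 'NCHW'
    try:
        get_data_format('NWHC')
    except ValueError as e:
        print(e)  # Unknown data_format: NWHC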