Shashank Suhas / seminar-breakout / Commits

Commit ea60a630, authored Feb 05, 2018 by Yuxin Wu

Translate data_format and activation to tflayers (#627)

parent 0e5299bb

Showing 8 changed files with 67 additions and 53 deletions (+67 -53)
tensorpack/models/batch_norm.py    +8  -9
tensorpack/models/conv2d.py        +10 -19
tensorpack/models/fc.py            +3  -3
tensorpack/models/layer_norm.py    +5  -2
tensorpack/models/pool.py          +9  -11
tensorpack/models/registry.py      +14 -0
tensorpack/tfutils/argscope.py     +5  -7
tensorpack/utils/argtools.py       +13 -2
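In user code, the change boils down to two renames: the `nl` keyword becomes `activation`, and the `'NHWC'`/`'NCHW'` strings become `'channels_last'`/`'channels_first'`. The shim added in registry.py below (`map_tfargs`) keeps the legacy spellings working. A minimal before/after sketch, assuming a normal tensorpack model-building context (layer and tensor names are illustrative):

import tensorflow as tf
from tensorpack.models import Conv2D

image = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Legacy spelling, still accepted and translated by the registry shim:
l = Conv2D('conv0', image, 64, 3, nl=tf.nn.relu, data_format='NHWC')

# tf.layers spelling introduced by this commit:
l = Conv2D('conv1', image, 64, 3, activation=tf.nn.relu, data_format='channels_last')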
tensorpack/models/batch_norm.py

@@ -8,6 +8,7 @@ from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
 from ..utils import logger
+from ..utils.argtools import get_data_format
 from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
@@ -67,7 +68,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
 @layer_register()
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
               use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0), data_format='NHWC',
+              gamma_init=tf.constant_initializer(1.0), data_format='channels_last',
               internal_update=False):
     """
     Batch Normalization layer, as described in the paper:
@@ -109,6 +111,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         don't want to fine tune the EMA. EMA will not be updated in
         this case.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -181,7 +184,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
 @layer_register()
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None, data_format='NHWC'):
+                use_scale=True, use_bias=True, gamma_init=None, data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
     `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
@@ -210,18 +214,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     ndims = len(shape)
     assert ndims in [2, 4]
     if ndims == 2:
-        data_format = 'NHWC'    # error using NCHW? (see #190)
+        data_format = 'channels_last'    # error using NCHW? (see #190)
         x = tf.reshape(x, [-1, 1, 1, shape[1]])
-    if data_format == 'NCHW':
-        n_out = shape[1]
-    else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchRenorm cannot have unknown channels!"

     ctx = get_current_tower_context()
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
-        axis=1 if data_format == 'NCHW' else 3,
+        axis=1 if data_format == 'channels_first' else 3,
         momentum=decay, epsilon=epsilon,
         center=use_bias, scale=use_scale,
         renorm=True,
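Note: the new `get_data_format(data_format, tfmode=False)` call means `BatchNorm`/`BatchRenorm` accept either naming scheme and normalize it back to `'NHWC'`/`'NCHW'` for their internal shape logic. A small sketch, assuming a tensorpack tower context and a 4D NCHW feature tensor `feat` (names illustrative):

from tensorpack.models import BatchNorm

bn = BatchNorm('bn', feat, data_format='channels_first')
bn_legacy = BatchNorm('bn2', feat, data_format='NCHW')   # legacy spelling still converts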
tensorpack/models/conv2d.py

@@ -6,7 +6,7 @@
 import tensorflow as tf

 from .common import layer_register, VariableHolder, rename_get_variable
 from ..tfutils.common import get_tf_version_number
-from ..utils.argtools import shape2d, shape4d
+from ..utils.argtools import shape2d, shape4d, get_data_format

 __all__ = ['Conv2D', 'Deconv2D']
@@ -15,8 +15,8 @@ __all__ = ['Conv2D', 'Deconv2D']
 def Conv2D(x, out_channel, kernel_shape,
            padding='SAME', stride=1,
            W_init=None, b_init=None,
-           nl=tf.identity, split=1, use_bias=True,
-           data_format='NHWC', dilation_rate=1):
+           activation=tf.identity, split=1, use_bias=True,
+           data_format='channels_last', dilation_rate=1):
     """
     2D convolution on 4D inputs.
@@ -30,9 +30,7 @@ def Conv2D(x, out_channel, kernel_shape,
         split (int): Split channels as used in Alexnet. Defaults to 1 (no split).
         W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
-        data_format (str): 'NHWC' or 'NCHW'.
         dilation_rate: (h, w) tuple or a int.

     Returns:
@@ -43,6 +41,7 @@ def Conv2D(x, out_channel, kernel_shape,
         * ``W``: weights
         * ``b``: bias
     """
+    data_format = get_data_format(data_format, tfmode=False)
     in_shape = x.get_shape().as_list()
     channel_axis = 3 if data_format == 'NHWC' else 1
     in_channel = in_shape[channel_axis]
@@ -79,7 +78,7 @@ def Conv2D(x, out_channel, kernel_shape,
                    for i, k in zip(inputs, kernels)]
         conv = tf.concat(outputs, channel_axis)
-    ret = nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+    ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')

     ret.variables = VariableHolder(W=W)
     if use_bias:
         ret.variables.b = b
@@ -90,8 +89,8 @@ def Conv2D(x, out_channel, kernel_shape,
 def Deconv2D(x, out_channel, kernel_shape,
              stride, padding='SAME',
              W_init=None, b_init=None,
-             nl=tf.identity, use_bias=True,
-             data_format='NHWC'):
+             activation=tf.identity, use_bias=True,
+             data_format='channels_last'):
     """
     2D deconvolution on 4D inputs.
@@ -104,7 +103,6 @@ def Deconv2D(x, out_channel, kernel_shape,
         padding (str): 'valid' or 'same'. Case insensitive.
         W_init: initializer for W. Defaults to `tf.variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
-        nl: a nonlinearity function.
         use_bias (bool): whether to use bias.

     Returns:
@@ -115,13 +113,6 @@ def Deconv2D(x, out_channel, kernel_shape,
         * ``W``: weights
         * ``b``: bias
     """
-    in_shape = x.get_shape().as_list()
-    channel_axis = 3 if data_format == 'NHWC' else 1
-    in_channel = in_shape[channel_axis]
-    assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
-    assert isinstance(out_channel, int), out_channel

     if W_init is None:
         W_init = tf.variance_scaling_initializer(scale=2.0)
     if b_init is None:
@@ -131,8 +122,8 @@ def Deconv2D(x, out_channel, kernel_shape,
         layer = tf.layers.Conv2DTranspose(
             out_channel, kernel_shape,
             strides=stride, padding=padding,
-            data_format='channels_last' if data_format == 'NHWC' else 'channels_first',
-            activation=lambda x: nl(x, name='output'),
+            data_format=data_format,
+            activation=activation,
             use_bias=use_bias,
             kernel_initializer=W_init,
             bias_initializer=b_init,
@@ -142,4 +133,4 @@ def Deconv2D(x, out_channel, kernel_shape,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
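Note: `Deconv2D` now forwards `activation` and `data_format` to `tf.layers.Conv2DTranspose` unchanged, and the output tensor gets its `'output'` name from the trailing `tf.identity` rather than from a `nl(..., name='output')` wrapper. A sketch, assuming a tensorpack model-building context with a 4D tensor `feat` (names illustrative):

import tensorflow as tf
from tensorpack.models import Deconv2D

up = Deconv2D('up1', feat, 32, 4, stride=2,
              activation=tf.nn.relu, data_format='channels_last')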
tensorpack/models/fc.py

@@ -14,7 +14,7 @@ __all__ = ['FullyConnected']
 @layer_register(log_shape=True)
 def FullyConnected(x, out_dim,
                    W_init=None, b_init=None,
-                   nl=tf.identity, use_bias=True):
+                   activation=tf.identity, use_bias=True):
     """
     Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor.
     It is an equivalent of `tf.layers.dense` except for naming conventions.
@@ -44,7 +44,7 @@ def FullyConnected(x, out_dim,
     with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
         layer = tf.layers.Dense(
-            out_dim, activation=lambda x: nl(x, name='output'), use_bias=use_bias,
+            out_dim, activation=activation, use_bias=use_bias,
             kernel_initializer=W_init, bias_initializer=b_init,
             trainable=True)
         ret = layer.apply(x, scope=tf.get_variable_scope())
@@ -52,4 +52,4 @@ def FullyConnected(x, out_dim,
     ret.variables = VariableHolder(W=layer.kernel)
     if use_bias:
         ret.variables.b = layer.bias
-    return ret
+    return tf.identity(ret, name='output')
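Note: `FullyConnected` follows the same pattern; the callable passed as `activation` goes straight to `tf.layers.Dense`. A sketch, assuming a tensorpack model-building context with a 2D tensor `feat` (names illustrative):

import tensorflow as tf
from tensorpack.models import FullyConnected

fc0 = FullyConnected('fc0', feat, 512, activation=tf.nn.relu)
fc1 = FullyConnected('fc1', fc0, 512, nl=tf.nn.relu)   # legacy kwarg, translated by map_tfargs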
tensorpack/models/layer_norm.py

@@ -5,6 +5,7 @@
 import tensorflow as tf

 from .common import layer_register, VariableHolder
+from ..utils.argtools import get_data_format

 __all__ = ['LayerNorm', 'InstanceNorm']
@@ -13,7 +14,7 @@ __all__ = ['LayerNorm', 'InstanceNorm']
 def LayerNorm(
         x, epsilon=1e-5,
         use_bias=True, use_scale=True,
-        gamma_init=None, data_format='NHWC'):
+        gamma_init=None, data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
@@ -23,6 +24,7 @@ def LayerNorm(
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims in [2, 4]
@@ -62,7 +64,7 @@ def LayerNorm(
 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='NHWC'):
+def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization
@@ -73,6 +75,7 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         epsilon (float): avoid divide-by-zero
         use_affine (bool): whether to apply learnable affine transformation
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
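Note: `LayerNorm` and `InstanceNorm` likewise default to `'channels_last'` and normalize the string internally. A sketch, assuming a tensorpack model-building context (names illustrative):

from tensorpack.models import LayerNorm, InstanceNorm

ln = LayerNorm('ln', feat)                                              # default 'channels_last'
inorm = InstanceNorm('inorm', feat_nchw, data_format='channels_first')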
tensorpack/models/pool.py

@@ -7,7 +7,7 @@ import numpy as np
 from .shape_utils import StaticDynamicShape
 from .common import layer_register
-from ..utils.argtools import shape2d
+from ..utils.argtools import shape2d, get_data_format
 from ._test import TestModel
@@ -16,7 +16,7 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
 @layer_register(log_shape=True)
-def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def MaxPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Max Pooling on 4D tensors.
@@ -31,13 +31,12 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.max_pooling2d(x, shape, stride, padding,
-                                  'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.max_pooling2d(x, shape, stride, padding, data_format=data_format)
     return tf.identity(ret, name='output')


 @layer_register(log_shape=True)
-def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
+def AvgPooling(x, shape, stride=None, padding='VALID', data_format='channels_last'):
     """
     Average Pooling on 4D tensors.
@@ -52,13 +51,12 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     if stride is None:
         stride = shape
-    ret = tf.layers.average_pooling2d(x, shape, stride, padding,
-                                      'channels_last' if data_format == 'NHWC' else 'channels_first')
+    ret = tf.layers.average_pooling2d(x, shape, stride, padding, data_format=data_format)
     return tf.identity(ret, name='output')


 @layer_register(log_shape=True)
-def GlobalAvgPooling(x, data_format='NHWC'):
+def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
     <http://arxiv.org/abs/1312.4400>`_.
@@ -69,8 +67,7 @@ def GlobalAvgPooling(x, data_format='NHWC'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
-    assert data_format in ['NHWC', 'NCHW']
-    axis = [1, 2] if data_format == 'NHWC' else [2, 3]
+    axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
@@ -90,7 +87,7 @@ def UnPooling2x2ZeroFilled(x):
 @layer_register(log_shape=True)
-def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
+def FixedUnPooling(x, shape, unpool_mat=None, data_format='channels_last'):
     """
     Unpool the input with a fixed matrix to perform kronecker product with.
@@ -103,6 +100,7 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
     Returns:
         tf.Tensor: a 4D image tensor.
     """
+    data_format = get_data_format(data_format, tfmode=False)
     shape = shape2d(shape)
     output_shape = StaticDynamicShape(x)
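Note: unlike the normalization layers, `GlobalAvgPooling` no longer inspects the legacy strings itself; as I read the diff, legacy callers are still covered because `map_tfargs` in the registry converts `'NHWC'`/`'NCHW'` before the function body runs. A sketch, assuming a tensorpack model-building context with a 4D NHWC tensor `feat` (names illustrative):

from tensorpack.models import GlobalAvgPooling

p = GlobalAvgPooling('gap', feat)                        # default 'channels_last'
p2 = GlobalAvgPooling('gap2', feat, data_format='NHWC')  # legacy string, mapped by the registry shim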
tensorpack/models/registry.py

@@ -11,6 +11,7 @@ import copy
 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.model_utils import get_shape_str
+from ..utils.argtools import get_data_format
 from ..utils import logger

 # make sure each layer is only logged once
@@ -20,6 +21,18 @@ _LAYER_REGISTRY = {}
 __all__ = ['layer_register']


+def map_tfargs(kwargs):
+    df = kwargs.pop('data_format', None)
+    if df is not None:
+        df = get_data_format(df, tfmode=True)
+        kwargs['data_format'] = df
+
+    old_nl = kwargs.pop('nl', None)
+    if old_nl is not None:
+        kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
+    return kwargs
+
+
 def _register(name, func):
     if name in _LAYER_REGISTRY:
         raise ValueError("Layer named {} is already registered!".format(name))
@@ -113,6 +126,7 @@ def layer_register(
             if k in actual_args:
                 del actual_args[k]

+        actual_args = map_tfargs(actual_args)
         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
                 # this name is only used to surpress logging, doesn't hurt to do some heuristics
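`map_tfargs` is the backward-compatibility shim: it rewrites a layer's keyword arguments into `tf.layers` conventions before the layer function runs. A sketch of its effect, based on the hunk above:

import tensorflow as tf
from tensorpack.models.registry import map_tfargs

kwargs = {'data_format': 'NHWC', 'nl': tf.nn.relu}
kwargs = map_tfargs(kwargs)
# kwargs is now {'data_format': 'channels_last', 'activation': <lambda wrapping tf.nn.relu>}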
tensorpack/tfutils/argscope.py

@@ -4,9 +4,7 @@
 from contextlib import contextmanager
 from collections import defaultdict
-import inspect
 import copy
-import six

 __all__ = ['argscope', 'get_arg_scope']
@@ -35,14 +33,14 @@ def argscope(layers, **kwargs):
     if not isinstance(layers, list):
         layers = [layers]

-    def _check_args_exist(l):
-        args = inspect.getargspec(l).args
-        for k, v in six.iteritems(kwargs):
-            assert k in args, "No argument {} in {}".format(k, l.__name__)
+    # def _check_args_exist(l):
+    #     args = inspect.getargspec(l).args
+    #     for k, v in six.iteritems(kwargs):
+    #         assert k in args, "No argument {} in {}".format(k, l.__name__)

     for l in layers:
         assert hasattr(l, 'symbolic_function'), "{} is not a registered layer".format(l.__name__)
-        _check_args_exist(l.symbolic_function)
+        # _check_args_exist(l.symbolic_function)

     new_scope = copy.copy(get_arg_scope())
     for l in layers:
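My reading of why the signature check was commented out: an `argscope` may now carry legacy kwargs (e.g. `nl`) that no longer appear in the rewritten signatures but are still translated by `map_tfargs` at call time, so the `inspect`-based assertion would reject valid scopes. An illustrative sketch (not part of the commit; names illustrative):

import tensorflow as tf
from tensorpack.tfutils.argscope import argscope
from tensorpack.models import Conv2D

with argscope(Conv2D, nl=tf.nn.relu, data_format='NHWC'):
    l = Conv2D('c1', image, 32, 3)
    l = Conv2D('c2', l, 32, 3)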
tensorpack/utils/argtools.py

@@ -111,7 +111,18 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))


-def shape4d(a, data_format='NHWC'):
+def get_data_format(data_format, tfmode=True):
+    if tfmode:
+        dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
+    else:
+        dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
+    ret = dic.get(data_format, data_format)
+    if ret not in dic.values():
+        raise ValueError("Unknown data_format: {}".format(data_format))
+    return ret
+
+
+def shape4d(a, data_format='channels_last'):
     """
     Ensuer a 4D shape, to use with 4D symbolic functions.
@@ -123,7 +134,7 @@ def shape4d(a, data_format='NHWC'):
         or ``[1, 1, a, a]`` depending on data_format.
     """
     s2d = shape2d(a)
-    if data_format == 'NHWC':
+    if get_data_format(data_format) == 'channels_last':
         return [1] + s2d + [1]
     else:
         return [1, 1] + s2d
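`get_data_format` is the two-way translator the rest of the commit relies on: with `tfmode=True` it maps legacy strings to the `tf.layers` names, with `tfmode=False` it maps back. Quick usage sketch, per the implementation above:

from tensorpack.utils.argtools import get_data_format

get_data_format('NHWC')                         # -> 'channels_last'
get_data_format('channels_first')               # -> 'channels_first' (already tf-style)
get_data_format('channels_last', tfmode=False)  # -> 'NHWC'
get_data_format('NC', tfmode=False)             # -> raises ValueError: Unknown data_format: NC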