Commit 07e464d8, authored Mar 08, 2020 by Yuxin Wu
standarize arg names in LayerNorm/InstanceNorm

Parent: 2ff9a5f4
Showing 3 changed files with 49 additions and 25 deletions:

    tensorpack/input_source/input_source.py   +2   -0
    tensorpack/models/batch_norm.py           +2   -2
    tensorpack/models/layer_norm.py           +45  -23
tensorpack/input_source/input_source.py

@@ -471,6 +471,8 @@ class TFDatasetInput(FeedfreeInput):
         self._spec = input_signature
         if self._dataset is not None:
             types = self._dataset.output_types
+            if len(types) == 1:
+                types = (types,)
             spec_types = tuple(k.dtype for k in input_signature)
             assert len(types) == len(spec_types), \
                 "Dataset and input signature have different length! {} != {}".format(
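The two added lines normalize `types` before the length comparison: a dataset whose output types arrive as a single-component structure is wrapped into a one-element tuple so it lines up with `spec_types`. A minimal plain-Python sketch of the guard, using hypothetical stand-in values rather than a real tf.data.Dataset:

    # Hypothetical stand-ins for dataset.output_types and the signature dtypes.
    types = ["float32"]        # a single-component structure
    spec_types = ("float32",)  # tuple(k.dtype for k in input_signature)

    # The guard added in this commit: wrap the single-component case in a
    # tuple so the element-wise comparison below is well-defined.
    if len(types) == 1:
        types = (types,)
    assert len(types) == len(spec_types), \
        "Dataset and input signature have different length! {} != {}".format(
            len(types), len(spec_types))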
tensorpack/models/batch_norm.py

@@ -71,7 +71,7 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
         'use_local_stat': 'training'
     })
 @disable_autograph()
-def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
+def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               center=True, scale=True,
               beta_initializer=tf.zeros_initializer(),
               gamma_initializer=tf.ones_initializer(),

@@ -376,7 +376,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         'gamma_init': 'gamma_initializer',
         'decay': 'momentum'
     })
-def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+def BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
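The only functional change in this file is the bare `*` inserted into both signatures, which makes every argument after `axis` (for BatchNorm) and after `dmax` (for BatchRenorm) keyword-only. A hedged sketch of what that means at a call site, with a made-up layer name and TF1-style placeholder input:

    import tensorflow as tf
    from tensorpack.models import BatchNorm

    x = tf.placeholder(tf.float32, [None, 32, 32, 3])  # hypothetical input

    # Fine before and after this commit: arguments passed by keyword.
    y = BatchNorm('bn', x, training=True, momentum=0.99)

    # After the bare `*`, passing `training` positionally raises TypeError:
    # y = BatchNorm('bn', x, None, True)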
tensorpack/models/layer_norm.py

@@ -5,16 +5,26 @@
 from ..compat import tfv1 as tf  # this should be avoided first in model code
 from ..utils.argtools import get_data_format
+from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
+from .tflayer import convert_to_tflayer_args

 __all__ = ['LayerNorm', 'InstanceNorm']


 @layer_register()
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+    })
 def LayerNorm(
-        x, epsilon=1e-5,
-        use_bias=True, use_scale=True,
-        gamma_init=None, data_format='channels_last'):
+        x, epsilon=1e-5, *,
+        center=True, scale=True,
+        gamma_initializer=tf.ones_initializer(),
+        data_format='channels_last'):
     """
     Layer Normalization layer, as described in the paper:
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

@@ -22,7 +32,7 @@ def LayerNorm(
     Args:
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        center, scale (bool): whether to use the extra affine transformation or not.
     """
     data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()

@@ -40,15 +50,13 @@ def LayerNorm(
     if ndims == 2:
         new_shape = [1, chan]

-    if use_bias:
+    if center:
         beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
         beta = tf.reshape(beta, new_shape)
     else:
         beta = tf.zeros([1] * ndims, name='beta')
-    if use_scale:
-        if gamma_init is None:
-            gamma_init = tf.constant_initializer(1.0)
-        gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
+    if scale:
+        gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
         gamma = tf.reshape(gamma, new_shape)
     else:
         gamma = tf.ones([1] * ndims, name='gamma')

@@ -56,15 +64,22 @@ def LayerNorm(
     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret


 @layer_register()
-def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'gamma_init': 'gamma_initializer',
+    })
+def InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
+                 gamma_initializer=tf.ones_initializer(),
+                 data_format='channels_last', use_affine=None):
     """
     Instance Normalization, as in the paper:
     `Instance Normalization: The Missing Ingredient for Fast Stylization

@@ -73,12 +88,17 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
     Args:
         x (tf.Tensor): a 4D tensor.
         epsilon (float): avoid divide-by-zero
-        use_affine (bool): whether to apply learnable affine transformation
+        center, scale (bool): whether to use the extra affine transformation or not.
+        use_affine: deprecated. Don't use.
     """
     data_format = get_data_format(data_format, keras_mode=False)
     shape = x.get_shape().as_list()
     assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

+    if use_affine is not None:
+        log_deprecated("InstanceNorm(use_affine=)", "Use center= or scale= instead!", "2020-06-01")
+        center = scale = use_affine
+
     if data_format == 'NHWC':
         axis = [1, 2]
         ch = shape[3]

@@ -91,19 +111,21 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
     mean, var = tf.nn.moments(x, axis, keep_dims=True)

-    if not use_affine:
-        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
-
-    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
-    beta = tf.reshape(beta, new_shape)
-    if gamma_init is None:
-        gamma_init = tf.constant_initializer(1.0)
-    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
-    gamma = tf.reshape(gamma, new_shape)
+    if center:
+        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
+        beta = tf.reshape(beta, new_shape)
+    else:
+        beta = tf.zeros([1, 1, 1, 1], name='beta', dtype=x.dtype)
+    if scale:
+        gamma = tf.get_variable('gamma', [ch], initializer=gamma_initializer)
+        gamma = tf.reshape(gamma, new_shape)
+    else:
+        gamma = tf.ones([1, 1, 1, 1], name='gamma', dtype=x.dtype)

     ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

     vh = ret.variables = VariableHolder()
-    if use_affine:
+    if scale:
         vh.gamma = gamma
+    if center:
         vh.beta = beta
     return ret
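Taken together, the renames are backward-compatible: the convert_to_tflayer_args decorators rewrite the old keyword spellings (use_bias, use_scale, gamma_init) into the new ones (center, scale, gamma_initializer) before the function body runs, and InstanceNorm's use_affine= keeps working but routes through log_deprecated until 2020-06-01. A hedged usage sketch, with made-up layer names and a TF1-style placeholder input:

    import tensorflow as tf
    from tensorpack.models import LayerNorm, InstanceNorm

    x = tf.placeholder(tf.float32, [None, 32, 32, 3])  # hypothetical input

    # New, preferred spellings (keyword-only, because of the bare `*`):
    y1 = LayerNorm('ln', x, center=False, scale=True)
    y2 = InstanceNorm('inorm', x, scale=False)

    # Old spellings still work via the name_mapping in the decorators;
    # use_bias=False is rewritten to center=False before LayerNorm runs.
    y3 = LayerNorm('ln_old', x, use_bias=False)

    # Deprecated path: logs a warning, then sets center = scale = use_affine.
    y4 = InstanceNorm('inorm_old', x, use_affine=True)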