Commit dda6fd53
Authored Dec 19, 2018 by Yuxin Wu

remove some use of contrib for tf1.13

parent ab81a75d

Showing 14 changed files with 64 additions and 35 deletions
.github/pull_request_template.md        +6   -8
README.md                               +1   -1
docs/tutorial/intro.rst                 +1   -1
examples/FasterRCNN/README.md           +4   -4
examples/FasterRCNN/config.py           +2   -0
examples/FasterRCNN/model_frcnn.py      +2   -3
examples/README.md                      +1   -1
tensorpack/graph_builder/training.py    +6   -5
tensorpack/graph_builder/utils.py       +5   -1
tensorpack/libinfo.py                   +1   -0
tensorpack/models/batch_norm.py         +9   -6
tensorpack/models/conv2d.py             +12  -2
tensorpack/models/fc.py                 +7   -1
tensorpack/models/regularize.py         +7   -2
.github/pull_request_template.md

Thanks for your contribution!

Unless you want to send a simple several lines of PR that can be easily merged, please note the following:

* If you want to add a new feature, please open an issue first and indicate that you want to contribute.
  There are features that we prefer to not add to tensorpack, e.g. symbolic models
  (see details at https://tensorpack.readthedocs.io/tutorial/symbolic.html).
  Therefore it's good to have a discussion first.
* If you want to add a new example, please note that:
  1. We prefer to not have an example that is too similar to existing ones in terms of the tasks.
  2. Examples have to be able to reproduce (preferably in some measurable metrics) published or well-known experiments and results.
* Please run `flake8 .` under the root of this repo to lint your code, and make sure the command produces no output.
README.md

@@ -41,7 +41,7 @@ demonstrating its __flexibility__ for actual research.
 ### Vision:
 + [Train ResNet](examples/ResNet) and [other models](examples/ImageNetModels) on ImageNet.
-+ [Train Faster-RCNN / Mask-RCNN on COCO object detection](examples/FasterRCNN)
++ [Train Mask/Faster R-CNN on COCO object detection](examples/FasterRCNN)
 + [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
 + [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
...
docs/tutorial/intro.rst

@@ -22,7 +22,7 @@ No it's not, but it's not easy to write it in an efficient way.
 When **speed** is a concern, users will have to worry a lot about things unrelated to the model.
 Code written with low-level APIs or other existing high-level wrappers is often suboptimal in speed.
 Even most of the official TensorFlow examples are written for simplicity rather than efficiency,
-which as a result makes people think TensorFlow is __slow__.
+which as a result makes people think TensorFlow is *slow*.
 The `official TensorFlow benchmark <https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks>`_ said this in their README:
...
examples/FasterRCNN/README.md

@@ -87,11 +87,11 @@ FPN models are sometimes slightly worse, which is mainly due to batch size.
 | Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/>(box;mask) | Time (on 8 V100s) | Configurations <br/> (click to expand) |
 | - | - | - | - | - |
-| R50-C4 | 33.1 | | 18h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[150000,230000,280000]` </details> |
+| R50-C4 | 33.1 | | 18h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[150000,230000,280000]` </details> |
 | R50-C4 | 36.6 | 36.5 | 44h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
-| R50-FPN | 37.4 | 37.9 | 27h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
+| R50-FPN | 37.4 | 37.9 | 23h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
-| R50-C4 | 38.2;33.3 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50C4-MaskRCNN-Standard.npz) | 37.8;32.8 | 48h | <details><summary>standard</summary>this is the default </details> |
+| R50-C4 | 38.2;33.3 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50C4-MaskRCNN-Standard.npz) | 37.8;32.8 | 49h | <details><summary>standard</summary>this is the default </details> |
-| R50-FPN | 38.4;35.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-Standard.npz) | 38.6;34.5 | 28h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
+| R50-FPN | 38.4;35.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-Standard.npz) | 38.6;34.5 | 27h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
 | R50-FPN | 42.0;36.3 | | 41h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
 | R50-FPN | 39.5;35.2 | 39.5;34.4<sup>[2](#ft2)</sup> | 33h | <details><summary>+ConvGNHead</summary>`MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head` </details> |
 | R50-FPN | 40.0;36.2 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz) | 40.3;35.7 | 40h | <details><summary>+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` |
...
examples/FasterRCNN/config.py

@@ -234,6 +234,8 @@ def finalize_configs(is_training):
     if is_training:
         train_scales = _C.PREPROC.TRAIN_SHORT_EDGE_SIZE
+        if not isinstance(train_scales, (list, tuple)):
+            train_scales = [train_scales, train_scales]
         if train_scales[1] - train_scales[0] > 100:
             # don't warmup if augmentation is on
             os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
...
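The two added lines let `PREPROC.TRAIN_SHORT_EDGE_SIZE` be either a single number or a `[min, max]` pair before the range check runs. A minimal standalone sketch of that normalization follows. The function names `normalize_train_scales` and `uses_scale_augmentation` are hypothetical and do not exist in the repository; the threshold of 100 is taken from the hunk.

```python
def normalize_train_scales(train_scales):
    """Return a [min, max] pair; a scalar s becomes [s, s]."""
    if not isinstance(train_scales, (list, tuple)):
        train_scales = [train_scales, train_scales]
    return list(train_scales)


def uses_scale_augmentation(train_scales, threshold=100):
    """True when the scale range is wider than the threshold."""
    lo, hi = normalize_train_scales(train_scales)
    return hi - lo > threshold


print(normalize_train_scales(600))          # [600, 600]
print(uses_scale_augmentation(600))         # False
print(uses_scale_augmentation([640, 800]))  # True
```

Without the normalization step, a scalar setting such as `600` would fail at `train_scales[1]`, since an integer cannot be indexed.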
examples/FasterRCNN/model_frcnn.py

@@ -209,17 +209,16 @@ def fastrcnn_predictions(boxes, scores):
         selection = tf.image.non_max_suppression(
             box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH)
         selection = tf.gather(ids, selection)
-        # sort available in TF>1.4.0
-        # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
-        sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
         if get_tf_version_tuple() >= (1, 13):
+            sorted_selection = tf.sort(selection, direction='ASCENDING')
             mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1),
                                           values=tf.ones_like(sorted_selection, dtype=tf.bool),
                                           dense_shape=output_shape)
             mask = tf.sparse.to_dense(mask, default_value=False)
         else:
             # this function is deprecated by TF
+            sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
             mask = tf.sparse_to_dense(
                 sparse_indices=sorted_selection,
                 output_shape=output_shape,
...
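The pre-1.13 branch gets an ascending sort out of a top-k operation: it negates the values, takes the top k with k equal to the full size (which yields descending order), and negates the result. The sketch below reproduces that arithmetic in plain Python. `top_k_values` is a hypothetical stand-in for the values output of a top-k op, written only for illustration; it is not TensorFlow API.

```python
def top_k_values(values, k):
    """Hypothetical stand-in for a top-k op: the k largest entries, descending."""
    return sorted(values, reverse=True)[:k]


def ascending_sort_via_top_k(selection):
    """Sort ascending by negating, taking all entries via top-k, and negating back."""
    negated = [-x for x in selection]
    return [-x for x in top_k_values(negated, k=len(negated))]


selection = [7, 2, 9, 4]
print(ascending_sort_via_top_k(selection))  # [2, 4, 7, 9]
```

The largest negated value corresponds to the smallest original value, so the descending order of the negated list is the ascending order of the original one.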
examples/README.md

@@ -27,7 +27,7 @@ These are all the toy examples in tensorpack. They are supposed to be just demos
 | Name | Performance |
 | --- | --- |
 | Train [ResNet](ResNet), [ShuffleNet and other models](ImageNetModels) on ImageNet | reproduce paper |
-| [Train Faster-RCNN / Mask-RCNN on COCO](FasterRCNN) | reproduce paper |
+| [Train Mask/Faster R-CNN on COCO](FasterRCNN) | reproduce paper |
 | [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, <br/> Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN | visually reproduce |
 | [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net) | reproduce paper |
 | [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED) | visually reproduce |
...
tensorpack/graph_builder/training.py

@@ -314,12 +314,13 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         post_init_ops = []
         def log_failure(name, reason):
-            if name in trainable_names:
-                msg = "This variable is trainable, so this is probably a fatal error."
-            else:
-                msg = "This variable is non-trainable. Ignore this warning if you know it's OK to leave it out-of-sync."
             logger.warn("[ReplicatedTrainer] Do not know how to sync variable '{}' across GPUs. "
-                        "Reason: {} ".format(name, reason) + msg)
+                        "Reason: {} ".format(name, reason))
+            assert name not in trainable_names, \
+                "The aforementioned variable is trainable, so this is probably a fatal error."
+            logger.warn(
+                "[ReplicatedTrainer] This variable is non-trainable. "
+                "Ignore this warning if you know it's OK to leave it out-of-sync.")
         for v in all_vars:
             if not v.name.startswith('tower'):
...
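The change to `log_failure` turns the trainable case from a warning message into a hard failure via `assert`. The sketch below shows that control flow in isolation, assuming the standard-library `logging` module in place of tensorpack's logger. Passing `trainable_names` as an argument is a simplification for the example; in the hunk it comes from the enclosing scope.

```python
import logging

logger = logging.getLogger("replicated_trainer_sketch")


def log_failure(name, reason, trainable_names):
    """Warn about a variable that cannot be synced; fail hard if it is trainable."""
    logger.warning(
        "[ReplicatedTrainer] Do not know how to sync variable '%s' across GPUs. "
        "Reason: %s", name, reason)
    assert name not in trainable_names, \
        "The aforementioned variable is trainable, so this is probably a fatal error."
    logger.warning(
        "[ReplicatedTrainer] This variable is non-trainable. "
        "Ignore this warning if you know it's OK to leave it out-of-sync.")


trainable = {"conv1/W"}
log_failure("bn/mean/EMA", "example reason", trainable)  # warns twice, no error
try:
    log_failure("conv1/W", "example reason", trainable)
except AssertionError as err:
    print("fatal:", err)
```

Because the check is an `assert`, it is skipped when Python runs with the `-O` flag, in which case both warnings would be logged even for a trainable variable.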
tensorpack/graph_builder/utils.py

@@ -148,7 +148,11 @@ def allreduce_grads(all_grads, average):
     Returns:
         K x N: same as input, but each grad is replaced by the average over K devices.
     """
-    from tensorflow.contrib import nccl
+    if get_tf_version_tuple() <= (1, 12):
+        from tensorflow.contrib import nccl
+    else:
+        from tensorflow.python.ops import nccl_ops as nccl
     nr_tower = len(all_grads)
     if nr_tower == 1:
         return all_grads
...
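The version gate in this hunk relies on Python's lexicographic tuple comparison. The sketch below isolates that decision without importing TensorFlow. `nccl_module_path` is a hypothetical helper that only returns the module path named in each branch, and it assumes the version is a two-element `(major, minor)` tuple, which is what the comparison against `(1, 12)` suggests.

```python
def nccl_module_path(tf_version):
    """Pick the NCCL module path from a (major, minor) version tuple.

    The two paths are the ones named in the hunk; nothing is imported here.
    """
    if tf_version <= (1, 12):
        return "tensorflow.contrib.nccl"
    return "tensorflow.python.ops.nccl_ops"


for version in [(1, 10), (1, 12), (1, 13), (2, 0)]:
    print(version, "->", nccl_module_path(version))

# A three-element tuple compares as greater than its two-element prefix,
# so a patch-level tuple for 1.12.x would not satisfy the same gate.
print((1, 12, 1) <= (1, 12))  # False
```

The last line is why the two-element assumption matters: if the version tuple carried a patch component, a 1.12.x release would fall through to the newer import path.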
tensorpack/libinfo.py

@@ -54,6 +54,7 @@ try:
     assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
     _HAS_TF = True
 except ImportError:
+    print("Failed to import tensorflow.")
     _HAS_TF = False
...
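The context line in this hunk checks the major and minor components separately. The sketch below compares that form with a tuple comparison, to show where the two differ. How `_version` is built is not shown in the hunk, so the example assumes it is the dot-separated version string split into parts. Both function names are hypothetical.

```python
def componentwise_ok(version_str):
    """The check as written in the context line: major >= 1 and minor >= 3."""
    parts = version_str.split(".")
    return int(parts[0]) >= 1 and int(parts[1]) >= 3


def tuple_ok(version_str, minimum=(1, 3)):
    """Alternative: compare (major, minor) lexicographically as a tuple."""
    major, minor = version_str.split(".")[:2]
    return (int(major), int(minor)) >= minimum


for v in ["1.2.0", "1.3.0", "1.13.1", "2.0.0"]:
    print(v, componentwise_ok(v), tuple_ok(v))
```

The two checks agree for every 1.x version. They disagree for a version such as 2.0.0, where the minor component is below 3 even though the release is newer than 1.3.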
tensorpack/models/batch_norm.py

@@ -230,13 +230,16 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
                 "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
                 "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
-            try:
-                from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
-            except Exception:
-                pass
+            if TF_version <= (1, 12):
+                try:
+                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
+                except Exception:
+                    pass
+                else:
+                    _validate_and_load_nccl_so()
+                from tensorflow.contrib.nccl.ops import gen_nccl_ops
             else:
-                _validate_and_load_nccl_so()
-            from tensorflow.contrib.nccl.ops import gen_nccl_ops
+                from tensorflow.python.ops import gen_nccl_ops
             shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
             batch_mean = gen_nccl_ops.nccl_all_reduce(
                 input=batch_mean,
...
tensorpack/models/conv2d.py

@@ -29,7 +29,7 @@ def Conv2D(
         dilation_rate=(1, 1),
         activation=None,
         use_bias=True,
-        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0),
+        kernel_initializer=None,
         bias_initializer=tf.zeros_initializer(),
         kernel_regularizer=None,
         bias_regularizer=None,
@@ -48,6 +48,11 @@ def Conv2D(
     * ``W``: weights
     * ``b``: bias
     """
+    if kernel_initializer is None:
+        if get_tf_version_tuple() <= (1, 12):
+            kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
+        else:
+            kernel_initializer = tf.keras.initializers.VarianceScaling(2.0)
     if split == 1:
         with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
             layer = tf.layers.Conv2D(
@@ -134,7 +139,7 @@ def Conv2DTranspose(
         data_format='channels_last',
         activation=None,
         use_bias=True,
-        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0),
+        kernel_initializer=None,
         bias_initializer=tf.zeros_initializer(),
         kernel_regularizer=None,
         bias_regularizer=None,
@@ -151,6 +156,11 @@ def Conv2DTranspose(
     * ``W``: weights
     * ``b``: bias
     """
+    if kernel_initializer is None:
+        if get_tf_version_tuple() <= (1, 12):
+            kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
+        else:
+            kernel_initializer = tf.keras.initializers.VarianceScaling(2.0)
     with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
         layer = tf.layers.Conv2DTranspose(
...
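As displayed on this page, the assignment in the `<= (1, 12)` branch ends with a trailing comma, while the `else` branch does not. In Python, a trailing comma after an expression builds a one-element tuple, so on that branch `kernel_initializer` would be bound to a tuple wrapping the initializer rather than to the initializer itself. The sketch below demonstrates only this language behavior. `make_initializer` is a hypothetical stand-in that returns a plain dict, and whether the tuple causes a problem downstream is not shown in the diff.

```python
def make_initializer(scale):
    """Hypothetical stand-in for an initializer factory; returns a plain dict."""
    return {"scale": scale}


with_comma = make_initializer(2.0),    # trailing comma: a 1-tuple
without_comma = make_initializer(2.0)  # the object itself

print(type(with_comma).__name__)     # tuple
print(type(without_comma).__name__)  # dict
print(with_comma[0] == without_comma)  # True
```

The same pattern appears in every hunk that resolves a `None` default for `kernel_initializer` inside the function body.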
tensorpack/models/fc.py

@@ -5,6 +5,7 @@
 import tensorflow as tf
 import numpy as np
+from ..tfutils.common import get_tf_version_tuple
 from .common import layer_register, VariableHolder
 from .tflayer import convert_to_tflayer_args, rename_get_variable
@@ -30,7 +31,7 @@ def FullyConnected(
         units,
         activation=None,
         use_bias=True,
-        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0),
+        kernel_initializer=None,
         bias_initializer=tf.zeros_initializer(),
         kernel_regularizer=None,
         bias_regularizer=None,
@@ -45,6 +46,11 @@ def FullyConnected(
     * ``W``: weights of shape [in_dim, out_dim]
     * ``b``: bias
     """
+    if kernel_initializer is None:
+        if get_tf_version_tuple() <= (1, 12):
+            kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
+        else:
+            kernel_initializer = tf.keras.initializers.VarianceScaling(2.0)
     inputs = batch_flatten(inputs)
     with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
...
tensorpack/models/regularize.py

@@ -7,6 +7,7 @@ import re
 from ..utils import logger
 from ..utils.argtools import graph_memoized
+from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from .common import layer_register
@@ -19,8 +20,12 @@ def _log_once(msg):
     logger.info(msg)
-l2_regularizer = tf.contrib.layers.l2_regularizer
-l1_regularizer = tf.contrib.layers.l1_regularizer
+if get_tf_version_tuple() <= (1, 12):
+    l2_regularizer = tf.contrib.layers.l2_regularizer
+    l1_regularizer = tf.contrib.layers.l1_regularizer
+else:
+    l2_regularizer = tf.keras.regularizers.l2
+    l1_regularizer = tf.keras.regularizers.l1
 def regularize_cost(regex, func, name='regularize_cost'):
...
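The two branches bind `l2_regularizer` to different functions, and the two may not compute the same penalty for the same coefficient. The sketch below assumes the commonly documented definitions: the contrib version scales `tf.nn.l2_loss`, which is half the sum of squares, while the Keras version scales the full sum of squares. Both formulas are assumptions restated in plain Python; they are not taken from this diff, and the function names are hypothetical.

```python
def contrib_style_l2(weights, scale):
    """Assumed contrib formula: scale * sum(w^2) / 2."""
    return scale * sum(w * w for w in weights) / 2.0


def keras_style_l2(weights, coeff):
    """Assumed Keras formula: coeff * sum(w^2)."""
    return coeff * sum(w * w for w in weights)


weights = [1.0, 2.0, 3.0]
print(contrib_style_l2(weights, 0.5))  # 3.5
print(keras_style_l2(weights, 0.5))    # 7.0
```

Under these assumptions, the same weight-decay coefficient yields a penalty twice as large on the Keras branch, which is worth checking when comparing training runs across the version boundary.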