Shashank Suhas / seminar-breakout

Commit 85586fc5
authored Jul 26, 2018 by Yuxin Wu

update docs

parent 72674731

Showing 6 changed files with 43 additions and 32 deletions
examples/FasterRCNN/README.md (+11 -11)
examples/ImageNetModels/README.md (+5 -4)
examples/ResNet/imagenet-resnet.py (+1 -1)
tensorpack/dataflow/dataset/ilsvrc.py (+2 -1)
tensorpack/models/batch_norm.py (+8 -2)
tensorpack/tfutils/gradproc.py (+16 -13)
examples/FasterRCNN/README.md

@@ -70,17 +70,17 @@ Evaluation or prediction will need the same `--config` used during training.
 These models are trained with different configurations on trainval35k and evaluated on minival using mAP@IoU=0.50:0.95.
 MaskRCNN results contain both box and mask mAP.
 | Backbone | mAP<br/>(box;mask) | Detectron mAP<br/>(box;mask) | Time on 8 V100s | Configurations<br/>(click to expand) |
 | - | - | - | - | - |
 | R50-C4 | 33.1 | | 18h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[150000,230000,280000]` </details> |
 | R50-C4 | 36.6 | 36.5 | 44h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
 | R50-FPN | 37.4 | 37.9 <sup>[1](#ft1)</sup> | 30h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
 | R50-C4 | 37.8;33.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50C4-MaskRCNN-Standard.npz) | 37.8;32.8 | 49h | <details><summary>standard</summary>`MODE_MASK=True` </details> |
 | R50-FPN | 38.2;34.9 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-Standard.npz) | 38.6;34.5 <sup>[1](#ft1)</sup> | 32h | <details><summary>standard</summary>`MODE_MASK=True MODE_FPN=True` </details> |
 | R50-FPN | 38.5;34.8 | 38.6;34.2 <sup>[2](#ft2)</sup> | 34h | <details><summary>standard+ConvHead</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_head` </details> |
 | R50-FPN | 39.5;35.2 | 39.5;34.4 <sup>[2](#ft2)</sup> | 34h | <details><summary>standard+ConvGNHead</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head` </details> |
-| R50-FPN | 40.0;36.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz) | 40.3;35.7 | 44h | <details><summary>standard+GN</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head |
+| R50-FPN | 40.0;36.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz) | 40.3;35.7 | 44h | <details><summary>standard+GN</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` |
 | R101-C4 | 40.8;35.1 [:arrow_down:](http://models.tensorpack.com/FasterRCNN/COCO-R101C4-MaskRCNN-Standard.npz) | | 63h | <details><summary>standard</summary>`MODE_MASK=True `<br/>`BACKBONE.RESNET_NUM_BLOCK=[3,4,23,3]` </details> |
 <a id="ft1">1</a>: This implementation has slightly different configurations from detectron (e.g. batch size).
examples/ImageNetModels/README.md

@@ -21,9 +21,10 @@ To print flops:
 ```
 It will print about 75Mflops, because the paper counts multiply+add as 1 flop.
-Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
+Download and evaluate the pretrained model:
 ```
-./shufflenet.py --eval --data /path/to/ilsvrc --load /path/to/model
+wget http://models.tensorpack.com/ImageNetModels/ShuffleNet.npz
+./shufflenet.py --eval --data /path/to/ilsvrc --load ShuffleNet.npz
 ```
 ### AlexNet
@@ -50,8 +51,8 @@ See `./vgg16.py --help` for usage.
 |:------------------------------------------|---------------------|--------------------:|
 | 29~30% (large variation with random seed) | 28% | 27.6% |
 Note that the purpose of this experiment in the paper is not to claim GroupNorm is better
 than BatchNorm, therefore the training settings and hyperpameters have not been individually tuned for best accuracy.
 ### ResNet
examples/ResNet/imagenet-resnet.py

@@ -112,7 +112,7 @@ if __name__ == '__main__':
     parser.add_argument('--data', help='ILSVRC dataset dir')
     parser.add_argument('--load', help='load a model for training or evaluation')
     parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true')
-    parser.add_argument('--data_format', help='image data format',
+    parser.add_argument('--data-format', help='image data format',
                         default='NCHW', choices=['NCHW', 'NHWC'])
     parser.add_argument('-d', '--depth', help='ResNet depth',
                         type=int, default=50, choices=[18, 34, 50, 101, 152])
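Renaming the flag to `--data-format` should not change how the value is read, because argparse derives the attribute name from the long option and replaces hyphens with underscores. A minimal standalone sketch of that behaviour follows; the parser here mirrors only the one option shown in the hunk and is not the script itself.

```python
import argparse

# Standalone parser mirroring only the renamed option from the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument('--data-format', help='image data format',
                    default='NCHW', choices=['NCHW', 'NHWC'])

# argparse turns '-' in the long option name into '_' for the attribute name,
# so the value is still read as args.data_format.
args = parser.parse_args(['--data-format', 'NHWC'])
print(args.data_format)

# Without the flag, the declared default applies.
default_args = parser.parse_args([])
print(default_args.data_format)
```

Code that reads `args.data_format` keeps working after the rename; only command lines that spell the flag with an underscore need updating.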
tensorpack/dataflow/dataset/ilsvrc.py

@@ -25,7 +25,7 @@ class ILSVRCMeta(object):
     def __init__(self, dir=None):
         if dir is None:
             dir = get_dataset_path('ilsvrc_metadata')
-        self.dir = dir
+        self.dir = os.path.expanduser(dir)
         mkdir_p(self.dir)
         f = os.path.join(self.dir, 'synsets.txt')
         if not os.path.isfile(f):
@@ -141,6 +141,7 @@ class ILSVRC12Files(RNGDataFlow):
             Same as in :class:`ILSVRC12`.
         """
         assert name in ['train', 'test', 'val'], name
+        dir = os.path.expanduser(dir)
         assert os.path.isdir(dir), dir
         self.full_dir = os.path.join(dir, name)
         self.name = name
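The `os.path.expanduser` calls added above matter because functions such as `os.path.isdir` and `os.path.join` treat a leading `~` as a literal directory name. The sketch below illustrates this with a hypothetical helper, `resolve_dataset_dir`, which is not part of tensorpack.

```python
import os


def resolve_dataset_dir(path):
    """Expand a leading '~' so later os.path checks see a real path.

    Hypothetical helper for illustration; not part of tensorpack.
    """
    return os.path.expanduser(path)


raw = '~/data/ilsvrc'
resolved = resolve_dataset_dir(raw)

# The unexpanded form would only pass os.path.isdir() if a directory
# literally named '~' existed in the current working directory.
print(raw, '->', resolved)

# Paths without a leading '~' pass through unchanged.
print(resolve_dataset_dir('/path/to/ilsvrc'))
```

Expanding once at the entry point, as the hunk does, means every later path operation works on the resolved directory.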
tensorpack/models/batch_norm.py

@@ -242,9 +242,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
                     shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
         elif sync_statistics == 'horovod':
             # Require https://github.com/uber/horovod/pull/331
+            import horovod
+            hvd_version = tuple(map(int, horovod.__version__.split('.')))
+            assert hvd_version >= (0, 13, 6), "sync_statistics needs horovod>=0.13.6 !"
             import horovod.tensorflow as hvd
-            batch_mean = hvd.allreduce(batch_mean, average=True)
-            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
+            if hvd.size() == 1:
+                logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
+            else:
+                batch_mean = hvd.allreduce(batch_mean, average=True)
+                batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
         batch_var = batch_mean_square - tf.square(batch_mean)
         batch_mean_vec = batch_mean
         batch_var_vec = batch_var
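The hunk above averages `batch_mean` and `batch_mean_square` across processes and then derives the variance as E[x^2] - (E[x])^2. The plain-Python sketch below illustrates why that reproduces the statistics of the combined batch. It assumes every worker holds the same number of samples, and `synced_moments` is a hypothetical name, not a tensorpack or Horovod function.

```python
from statistics import fmean, pvariance


def synced_moments(per_worker_batches):
    """Combine per-worker moments the way an averaging all-reduce would.

    Assumes every worker holds the same number of samples; otherwise a plain
    average of per-worker means is not the global mean.
    """
    means = [fmean(batch) for batch in per_worker_batches]
    mean_squares = [fmean(x * x for x in batch) for batch in per_worker_batches]

    # Stand-in for allreduce(..., average=True): average across workers.
    mean = fmean(means)
    mean_square = fmean(mean_squares)

    # Var[x] = E[x^2] - (E[x])^2
    return mean, mean_square - mean * mean


batches = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # two workers, three samples each
mean, var = synced_moments(batches)

# Compare against statistics computed over all samples at once.
all_samples = [x for batch in batches for x in batch]
print(mean, fmean(all_samples))
print(var, pvariance(all_samples))
```

With a single process the averaging step changes nothing, which is consistent with the hunk skipping the all-reduce and logging a warning when `hvd.size() == 1`.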
tensorpack/tfutils/gradproc.py

@@ -101,13 +101,17 @@ class MapGradient(GradientProcessor):
     """
     Apply a function on all gradient if the name matches regex.
     Keep the other gradients unchanged.
+    It can be used for gradient clipping, etc.
     """
     def __init__(self, func, regex='.*'):
         """
         Args:
-            func: takes a grad or (grad, var) pair and returns a grad. If return None, the
-                gradient is discarded (hence no update to the variable will happen).
+            func: a user-supplied function which takes one or two arguments.
+                The argument(s) can be either a `grad` tensor, or `grad` and `var`.
+                The function should return the new gradient to be used.
+                If it return None, the gradient is discarded (hence no update to the variable will happen).
             regex (str): used to match variables. Defaults to match all variables.
         """
         args = inspect.getargspec(func).args
@@ -196,15 +200,14 @@ class PrintGradient(MapGradient):
 class CheckGradient(MapGradient):
     """
-    Check for numeric issue.
-    See :func:`tf.check_numerics` for more information.
+    Run :func:`tf.check_numerics` for each gradient.
     """
     def __init__(self):
         super(CheckGradient, self).__init__(self._mapper)
     def _mapper(self, grad, var):
-        # this is very slow.... see #3649
+        # this was very slow.... see #3649
         # op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
         grad = tf.check_numerics(grad, 'CheckGradient/' + var.op.name)
         return grad
@@ -215,26 +218,26 @@ class ScaleGradient(MapGradient):
     Scale certain gradient by a multiplier.
     """
-    def __init__(self, multipliers, verbose=True, log=None):
+    def __init__(self, multipliers, verbose=True):
         """
         Args:
-            multipliers (tuple or list): tuple of (regex, float), or list of tuples.
+            multipliers (tuple or list): tuple of (regex, float), or list of such tuples.
             verbose (bool): whether to print logs or not
-            log: deprecated
         Example:
-            Use double learning rate for all the bias (as in caffe):
+            Use double learning rate for all the bias (as in caffe), and freeze layer0:
             .. code-block:: python
-                ScaleGradient(('.*/b', 2))
+                from tensorpack.tfutils import optimizer, gradproc
+                opt = optimizer.apply_grad_processors(
+                    opt, [gradproc.ScaleGradient(
+                        [('.*/b', 2.), ('layer0/.*', 0.)]
+                    )])
         """
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
         self.multipliers = multipliers
-        if log is not None:
-            logger.warn("'log' in ScaleGradient(..) is renamed to 'verbose'.")
-            verbose = log
         assert verbose in [True, False], verbose
         self._verbose = verbose
         super(ScaleGradient, self).__init__(self._mapper)
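The docstrings above describe a regex-gated mapping: a one- or two-argument function is applied to matching gradients, and a `None` return discards the gradient. The plain-Python sketch below illustrates that contract, with floats standing in for tensors and strings for variables. `map_gradients` and `scale_by` are hypothetical names, not tensorpack APIs. Full-string regex matching and dropping the gradient for a zero multiplier are assumptions of this sketch. It uses `inspect.signature` because `inspect.getargspec` was removed in Python 3.11.

```python
import inspect
import re


def map_gradients(grads_and_vars, func, regex='.*'):
    """Apply func to gradients whose variable name matches regex.

    grads_and_vars: list of (grad, var_name) pairs; floats and strings stand
    in for tensors and variables in this sketch.
    func: takes (grad) or (grad, var_name) and returns a new grad, or None to
    drop the pair so that no update happens for that variable.
    """
    nargs = len(inspect.signature(func).parameters)
    assert nargs in (1, 2), "func must take one or two arguments"
    pattern = re.compile(regex)

    result = []
    for grad, name in grads_and_vars:
        if pattern.fullmatch(name):  # assumption: the whole name must match
            new_grad = func(grad) if nargs == 1 else func(grad, name)
            if new_grad is None:
                continue  # discarded: no update for this variable
            grad = new_grad
        result.append((grad, name))
    return result


def scale_by(multipliers):
    """Build a two-argument func from a list of (regex, multiplier) pairs."""
    def func(grad, name):
        for regex, mult in multipliers:
            if re.fullmatch(regex, name):
                # Assumption: a zero multiplier freezes the variable.
                return None if mult == 0 else grad * mult
        return grad
    return func


grads = [(0.5, 'layer0/W'), (0.5, 'fc0/W'), (0.5, 'fc0/b')]
scaled = map_gradients(grads, scale_by([('.*/b', 2.), ('layer0/.*', 0.)]))
print(scaled)

# One-argument form, e.g. clipping every gradient to [-1, 1].
clipped = map_gradients([(3.0, 'fc0/W')], lambda g: max(-1.0, min(1.0, g)))
print(clipped)
```

Here the bias gradient is doubled, the `layer0` gradient is dropped, and the remaining gradient passes through unchanged, mirroring the intent of the `ScaleGradient` docstring example.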