Shashank Suhas / seminar-breakout
Commit 7bcde8ad, authored Jul 08, 2018 by Yuxin Wu
parent b58b3a78

update docs

Showing 5 changed files with 41 additions and 62 deletions:
examples/FasterRCNN/README.md (+9 -8)
examples/FasterRCNN/config.py (+4 -3)
tensorpack/graph_builder/model_desc.py (+6 -47)
tensorpack/train/interface.py (+15 -0)
tensorpack/train/tower.py (+7 -4)
examples/FasterRCNN/README.md
@@ -13,7 +13,8 @@ with the support of:
 ## Dependencies
 + Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
 + [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
-+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/ResNet/) from tensorpack model zoo.
++ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/)
+  from tensorpack model zoo. Use the models with "-AlignPadding".
 + COCO data. It needs to have the following directory structure:
 ```
 COCO/DIR/
@@ -37,7 +38,7 @@ To train:
 ./train.py --config \
     MODE_MASK=True MODE_FPN=True \
     DATA.BASEDIR=/path/to/COCO/DIR \
-    BACKBONE.WEIGHTS=/path/to/ImageNet-ResNet50.npz \
+    BACKBONE.WEIGHTS=/path/to/ImageNet-R50-Pad.npz \
 ```
 Options can be changed by either the command line or the `config.py` file.
 Recommended configurations are listed in the table below.
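Note on the hunk above: the `KEY=VALUE` pairs after `--config` override entries that are otherwise set in `config.py`. The following is a minimal sketch of how dotted overrides of this kind could be applied to a nested dictionary. The helper `apply_overrides` and the dictionary layout are hypothetical illustrations, not the example's actual configuration code.

```python
import ast


def apply_overrides(config, overrides):
    """Apply 'A.B=value' strings to a nested dict (hypothetical helper)."""
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node[name]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError("Unknown config option: " + dotted_key)
        try:
            # Parse Python literals such as True, 256 or [150000,230000,280000]
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value  # keep paths and other plain strings unchanged
        node[leaf] = value
    return config


# Assumed, simplified config layout using only option names from the README
config = {
    "MODE_MASK": False,
    "MODE_FPN": False,
    "DATA": {"BASEDIR": ""},
    "BACKBONE": {"WEIGHTS": ""},
}
apply_overrides(config, [
    "MODE_MASK=True",
    "MODE_FPN=True",
    "DATA.BASEDIR=/path/to/COCO/DIR",
    "BACKBONE.WEIGHTS=/path/to/ImageNet-R50-Pad.npz",
])
```

Rejecting unknown keys is a design choice of this sketch: a misspelled option then fails loudly instead of silently training with the default.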
@@ -50,13 +51,13 @@ To predict on an image (and show output in a window):
 ./train.py --predict input.jpg --load /path/to/model --config SAME-AS-TRAINING
 ```

-Evaluate the performance of a model on COCO, and save results to json.
-(Trained COCO models can be downloaded in [model zoo](http://models.tensorpack.com/FasterRCNN):
+Evaluate the performance of a model on COCO.
+(Several trained models can be downloaded in [model zoo](http://models.tensorpack.com/FasterRCNN):
 ```
-./train.py --evaluate output.json --load /path/to/COCO-ResNet50-MaskRCNN.npz \
+./train.py --evaluate output.json --load /path/to/COCO-R50C4-MaskRCNN-Standard.npz \
     --config MODE_MASK=True DATA.BASEDIR=/path/to/COCO/DIR
 ```

-Evaluation or prediction will need the same config used during training.
+Evaluation or prediction will need the same `--config` used during training.

 ## Results
@@ -69,8 +70,8 @@ MaskRCNN results contain both box and mask mAP.
 | R50-C4 | 36.6 | 36.5 | 44h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=False`</details> |
 | R50-FPN | 37.5 | 37.9<sup>[1](#ft1)</sup> | 28h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True`</details> |
 | R50-C4 | 36.8/32.1 | | 39h on 8 P100s | <details><summary>quick</summary>`MODE_MASK=True FRCNN.BATCH_PER_IM=256`<br/>`TRAIN.LR_SCHEDULE=[150000,230000,280000]`</details> |
-| R50-C4 | 37.8/33.1 | 37.8/32.8 | 45h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=True`</details> |
-| R50-FPN | 38.1/34.9 | 38.6/34.5<sup>[1](#ft1)</sup> | 32h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=True MODE_FPN=True`</details> |
+| R50-C4 | 37.8/33.1 | 37.8/32.8 | 49h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=True`</details> |
+| R50-FPN | 38.2/34.9 | 38.6/34.5<sup>[1](#ft1)</sup> | 32h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=True MODE_FPN=True`</details> |
 | R50-FPN | 38.5/34.8 | 38.6/34.2<sup>[2](#ft2)</sup> | 34h on 8 V100s | <details><summary>standard+ConvHead</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_head`</details> |
 | R50-FPN | 39.5/35.2 | 39.5/34.4<sup>[2](#ft2)</sup> | 34h on 8 V100s | <details><summary>standard+ConvGNHead</summary>`MODE_MASK=True MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`</details> |
 | R101-C4 | 40.8/35.1 | | 63h on 8 V100s | <details><summary>standard</summary>`MODE_MASK=True `<br/>`BACKBONE.RESNET_NUM_BLOCK=[3,4,23,3]`</details> |
examples/FasterRCNN/config.py
@@ -68,10 +68,11 @@ _C.BACKBONE.NORM = 'FreezeBN'  # options: FreezeBN, SyncBN
 # Use a base model with TF-preferred padding mode,
 # which may pad more pixels on right/bottom than top/left.
-# TF_PAD_MODE=False is better for accuracy but will require a different base model.
-# We will eventually switch to TF_PAD_MODE=False.
 # See https://github.com/tensorflow/tensorflow/issues/18213
-_C.BACKBONE.TF_PAD_MODE = True
+
+# In tensorpack model zoo, ResNet models with TF_PAD_MODE=False are marked with "-AlignPadding".
+# All other models under `ResNet/` in the model zoo are trained with TF_PAD_MODE=True.
+_C.BACKBONE.TF_PAD_MODE = False
 _C.BACKBONE.STRIDE_1X1 = False  # True for MSRA models

 # schedule -----------------------
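Note on the comment above: the "TF-preferred padding mode" refers to TensorFlow's `SAME` padding, which splits the total padding so that any odd remainder goes to the right/bottom side. A minimal sketch of that arithmetic along one axis follows. The function name `tf_same_padding` is chosen here for illustration and is not part of tensorpack or TensorFlow.

```python
import math


def tf_same_padding(size, kernel, stride):
    """Return (pad_before, pad_after) applied by 'SAME' padding along one axis."""
    out = math.ceil(size / stride)                       # output length under SAME
    total = max((out - 1) * stride + kernel - size, 0)   # total padding required
    before = total // 2                                  # top/left gets the smaller half
    return before, total - before                        # bottom/right gets the rest


# A 7x7 convolution with stride 2 on a 224-pixel axis (a typical first ResNet layer)
print(tf_same_padding(224, 7, 2))  # (2, 3): one extra pixel on the right/bottom

# A 3x3 convolution with stride 2 on a 56-pixel axis
print(tf_same_padding(56, 3, 2))   # (0, 1): all padding on the right/bottom
```

With stride 1 and an odd kernel the split is symmetric, so the asymmetry only shows up in strided layers.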
tensorpack/graph_builder/model_desc.py
@@ -8,7 +8,6 @@ import tensorflow as tf
 from ..utils import logger
 from ..utils.argtools import memoized
 from ..utils.develop import log_deprecated
-from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.tower import get_current_tower_context
 from ..input_source import InputSource
 from ..models.regularize import regularize_cost_from_collection
@@ -128,17 +127,16 @@ class ModelDescBase(object):
         """
         Build the whole symbolic graph.
         This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
         By default it will call :meth:`_build_graph` with a list of input tensors.
-        A subclass is expected to overwrite this method or the :meth:`_build_graph` method.
+        A subclass is expected to overwrite this method.

         Args:
             args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.

         Returns:
-            In general it returns nothing, but a subclass (e.g.
-            :class:`ModelDesc`) may require it to return necessary information
-            (e.g. cost) to build the trainer.
+            In general it returns nothing, but a subclass
+            may require it to return necessary information to build the trainer.
+            For example, `SingleCostTrainer` expect this method to return the cost tensor.
         """
         if len(args) == 1:
             arg = args[0]
@@ -230,51 +228,12 @@ class ModelDesc(ModelDescBase):
     def _build_graph_get_cost(self, *inputs):
         """
-        Used internally by trainers to get the final cost for optimization.
+        Used internally by trainers to get the final cost for optimization in a backward-compatible way.
         """
         ret = self.build_graph(*inputs)
         if not get_current_tower_context().is_training:
             return None     # this is the tower function, could be called for inference
         if isinstance(ret, tf.Tensor):  # the preferred way
-            assert ret.shape.ndims == 0, "Cost must be a scalar, but found a tensor of shape {}!".format(ret.shape)
-            _check_unused_regularization()
             return ret
-        else:   # the old way
+        else:   # the old way, for compatibility
             return self.get_cost()
-
-    # TODO this is deprecated and only used for v1 trainers
-    def _build_graph_get_grads(self, *inputs):
-        """
-        Build the graph from inputs and return the grads.
-
-        Returns:
-            [(grad, var)]
-        """
-        ctx = get_current_tower_context()
-        cost = self._build_graph_get_cost(*inputs)
-
-        if not ctx.is_training:
-            return None     # this is the tower function, could be called for inference
-
-        if ctx.has_own_variables:
-            varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
-        else:
-            varlist = tf.trainable_variables()
-        opt = self.get_optimizer()
-        grads = opt.compute_gradients(
-            cost, var_list=varlist,
-            gate_gradients=False, colocate_gradients_with_ops=True)
-        grads = FilterNoneGrad().process(grads)
-        return grads
-
-
-def _check_unused_regularization():
-    coll = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-    unconsumed_reg = []
-    for c in coll:
-        if len(c.consumers()) == 0:
-            unconsumed_reg.append(c)
-    if unconsumed_reg:
-        logger.warn("The following tensors appear in REGULARIZATION_LOSSES collection but has no "
-                    "consumers! You may have forgotten to add regularization to total cost.")
-        logger.warn("Unconsumed regularization: {}".format(', '.join([x.name for x in unconsumed_reg])))
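Note on `_build_graph_get_cost`: the method accepts two model styles, one whose `build_graph` returns the cost directly ("the preferred way") and one that stores the cost and exposes it through `get_cost()` ("the old way"). The sketch below mirrors that control flow without TensorFlow. `FakeTensor`, `NewStyleModel`, `OldStyleModel` and the explicit `is_training` argument are stand-ins invented for this illustration, so it should be read as an approximation of the dispatch and not as tensorpack code.

```python
class FakeTensor(object):
    """Stand-in for tf.Tensor so the sketch runs without TensorFlow."""
    def __init__(self, name):
        self.name = name


class NewStyleModel(object):
    def build_graph(self, *inputs):
        return FakeTensor("cost")  # preferred: return the cost directly


class OldStyleModel(object):
    def build_graph(self, *inputs):
        self.cost = FakeTensor("legacy_cost")  # old style: store the cost, return None

    def get_cost(self):
        return self.cost


def build_graph_get_cost(model, is_training, *inputs):
    """Return the cost for either model style, or None outside training."""
    ret = model.build_graph(*inputs)
    if not is_training:
        return None  # the tower function may also be called for inference
    if isinstance(ret, FakeTensor):
        return ret  # the preferred way
    return model.get_cost()  # the old way, for compatibility


print(build_graph_get_cost(NewStyleModel(), True).name)
print(build_graph_get_cost(OldStyleModel(), True).name)
print(build_graph_get_cost(NewStyleModel(), False))
```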
tensorpack/train/interface.py
 # -*- coding: utf-8 -*-
 # File: interface.py

+import tensorflow as tf
+
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
 from ..utils import logger
@@ -79,6 +81,7 @@ def launch_train_with_config(config, trainer):
     trainer.setup_graph(
         inputs_desc, input,
         model._build_graph_get_cost, model.get_optimizer)
+    _check_unused_regularization()
     trainer.train_with_defaults(
         callbacks=config.callbacks,
         monitors=config.monitors,
@@ -88,3 +91,15 @@ def launch_train_with_config(config, trainer):
         starting_epoch=config.starting_epoch,
         max_epoch=config.max_epoch,
         extra_callbacks=config.extra_callbacks)
+
+
+def _check_unused_regularization():
+    coll = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    unconsumed_reg = []
+    for c in coll:
+        if len(c.consumers()) == 0:
+            unconsumed_reg.append(c)
+    if unconsumed_reg:
+        logger.warn("The following tensors appear in REGULARIZATION_LOSSES collection but have no "
+                    "consumers! You may have forgotten to add regularization to total cost.")
+        logger.warn("Unconsumed regularization: {}".format(', '.join([x.name for x in unconsumed_reg])))
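Note on `_check_unused_regularization`: the check flags every tensor in the `REGULARIZATION_LOSSES` collection that no other op consumes, which usually means the regularization term was created but never added to the total cost. The sketch below reproduces the same filtering logic on stand-in objects so it can run without TensorFlow. `FakeTensor`, `find_unconsumed` and the tensor names are hypothetical and only model the two members the check relies on, `name` and `consumers()`.

```python
class FakeTensor(object):
    """Stand-in exposing the two members the check uses: `name` and `consumers()`."""
    def __init__(self, name, consumers=()):
        self.name = name
        self._consumers = list(consumers)

    def consumers(self):
        return self._consumers


def find_unconsumed(regularization_losses):
    """Return the names of losses that no other op consumes."""
    return [t.name for t in regularization_losses if len(t.consumers()) == 0]


losses = [
    FakeTensor("conv1/W/l2:0", consumers=["total_cost"]),  # used by the total cost
    FakeTensor("fc/W/l2:0"),                               # never added to the cost
]
unconsumed = find_unconsumed(losses)
if unconsumed:
    print("Unconsumed regularization: {}".format(", ".join(unconsumed)))
```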
tensorpack/train/tower.py
@@ -129,14 +129,14 @@ class SingleCostTrainer(TowerTrainer):
     Base class for single-cost trainer.

     Single-cost trainer has a :meth:`setup_graph` method which takes
-    (inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.
+    (inputs_desc, input, get_cost_fn, get_opt_fn), and build the training graph from them.

     To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
     """

     COLOCATE_GRADIENTS_WITH_OPS = True
     """
-    See `tf.gradients`. This might affect performance when backward op does
+    See `tf.gradients`. It sometimes can heavily affect performance when backward op does
     not support the device of forward op.
     """
@@ -159,7 +159,7 @@ class SingleCostTrainer(TowerTrainer):
                 optimizer. Will only be called once.

         Note:
-            `get_cost_fn` will be the tower function.
+            `get_cost_fn` will be part of the tower function.
             It must follows the `rules of tower function.
             <http://tensorpack.readthedocs.io/en/latest/tutorial/trainer.html#tower-trainer>`_.
         """
@@ -188,15 +188,18 @@ class SingleCostTrainer(TowerTrainer):
     def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
         """
+        Internal use only.
+
         Returns:
             a get_grad_fn for GraphBuilder to use.
         """
-        # internal use only
         assert input.setup_done()

         def get_grad_fn():
             ctx = get_current_tower_context()
             cost = get_cost_fn(*input.get_input_tensors())
+            assert isinstance(cost, tf.Tensor), cost
+            assert cost.shape.ndims == 0, "Cost must be a scalar, but found {}!".format(cost)
             if not ctx.is_training:
                 return None     # this is the tower function, could be called for inference
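Note on the checks in `get_grad_fn`: the trainer requires the value returned by `get_cost_fn` to be a tensor of rank 0. The sketch below shows the same two conditions on a stand-in type. `FakeTensor` (with `shape` as a plain tuple) and `check_cost` are hypothetical, and the sketch raises exceptions where the diff uses `assert`, so that it behaves the same when Python runs with `-O`.

```python
class FakeTensor(object):
    """Stand-in for tf.Tensor; `shape` is a plain tuple in this sketch."""
    def __init__(self, shape):
        self.shape = tuple(shape)


def check_cost(cost):
    """Return `cost` unchanged if it is a scalar FakeTensor, otherwise raise."""
    if not isinstance(cost, FakeTensor):
        raise TypeError("Cost must be a tensor, but found {!r}".format(cost))
    if len(cost.shape) != 0:
        raise ValueError("Cost must be a scalar, but found shape {}!".format(cost.shape))
    return cost


check_cost(FakeTensor(()))           # a rank-0 cost passes
try:
    check_cost(FakeTensor((32,)))    # a per-example loss vector is rejected
except ValueError as err:
    print(err)
```

A per-example loss of shape `(batch_size,)` is the usual way to trip this check; reducing it to a single value before returning is what the scalar requirement asks for.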