Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e233d835
Commit
e233d835
authored
May 22, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update fpn
parent
21d54280
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
36 deletions
+43
-36
examples/FasterRCNN/README.md
examples/FasterRCNN/README.md
+3
-3
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+11
-13
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+1
-1
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+14
-12
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+14
-7
No files found.
examples/FasterRCNN/README.md
View file @
e233d835
# Faster-RCNN / Mask-RCNN on COCO
# Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) implementation of
This example aims to provide a minimal (1.3k lines) implementation of
end-to-end Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO.
end-to-end Faster-RCNN & Mask-RCNN (with ResNet
& FPN
backbones) on COCO.
## Dependencies
## Dependencies
+
Python 3; TensorFlow >= 1.4.0
+
Python 3; TensorFlow >= 1.4.0
(>=1.6.0 recommended due to a TF bug);
+
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
Pre-trained
[
ResNet model
](
http://models.tensorpack.com/ResNet/
)
from tensorpack model zoo.
+
Pre-trained
[
ResNet model
](
http://models.tensorpack.com/ResNet/
)
from tensorpack model zoo.
+
COCO data. It assumes the following directory structure:
+
COCO data. It assumes the following directory structure:
...
@@ -61,7 +61,7 @@ MaskRCNN results contain both bbox and segm mAP.
...
@@ -61,7 +61,7 @@ MaskRCNN results contain both bbox and segm mAP.
|R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
The two R-50 360k models have the same configuration __and mAP__
The two R-50 360k models have the same configuration __and mAP__
as the
`R50-C4-2x`
entries in
as the
`R50-C4-2x`
entries in
[
Detectron Model Zoo
](
https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines
)
.
[
Detectron Model Zoo
](
https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines
)
.
So far this seems to be the only open source re-implementation that can reproduce mAP in Detectron.
So far this seems to be the only open source re-implementation that can reproduce mAP in Detectron.
...
...
examples/FasterRCNN/config.py
View file @
e233d835
...
@@ -5,11 +5,11 @@ import numpy as np
...
@@ -5,11 +5,11 @@ import numpy as np
# mode flags ---------------------
# mode flags ---------------------
MODE_MASK
=
True
MODE_MASK
=
True
MODE_FPN
=
False
# dataset -----------------------
# dataset -----------------------
BASEDIR
=
'/path/to/your/COCO/DIR'
BASEDIR
=
'/path/to/your/COCO/DIR'
TRAIN_DATASET
=
[
'train2014'
,
'valminusminival2014'
]
TRAIN_DATASET
=
[
'train2014'
,
'valminusminival2014'
]
# TRAIN_DATASET = ['valminusminival2014']
VAL_DATASET
=
'minival2014'
# only support evaluation on single dataset
VAL_DATASET
=
'minival2014'
# only support evaluation on single dataset
NUM_CLASS
=
81
NUM_CLASS
=
81
CLASS_NAMES
=
[]
# NUM_CLASS strings. Will be populated later by coco loader
CLASS_NAMES
=
[]
# NUM_CLASS strings. Will be populated later by coco loader
...
@@ -29,14 +29,14 @@ LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
...
@@ -29,14 +29,14 @@ LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
# image resolution --------------------
SHORT_EDGE_SIZE
=
800
SHORT_EDGE_SIZE
=
800
MAX_SIZE
=
1333
# TODO use 1344
MAX_SIZE
=
1333
# alternative (worse & faster) setting: 600, 1024
# alternative (worse & faster) setting: 600, 1024
# anchors -------------------------
# anchors -------------------------
ANCHOR_STRIDE
=
16
ANCHOR_STRIDE
=
16
ANCHOR_STRIDES_FPN
=
(
4
,
8
,
16
,
32
,
64
)
ANCHOR_STRIDES_FPN
=
(
4
,
8
,
16
,
32
,
64
)
# sqrtarea of the anchor box
FPN_RESOLUTION_REQUIREMENT
=
32
# image size into the backbone has to be multiple of this number
ANCHOR_SIZES
=
(
32
,
64
,
128
,
256
,
512
)
ANCHOR_SIZES
=
(
32
,
64
,
128
,
256
,
512
)
# sqrtarea of the anchor box
ANCHOR_RATIOS
=
(
0.5
,
1.
,
2.
)
ANCHOR_RATIOS
=
(
0.5
,
1.
,
2.
)
NUM_ANCHOR
=
len
(
ANCHOR_SIZES
)
*
len
(
ANCHOR_RATIOS
)
NUM_ANCHOR
=
len
(
ANCHOR_SIZES
)
*
len
(
ANCHOR_RATIOS
)
POSITIVE_ANCHOR_THRES
=
0.7
POSITIVE_ANCHOR_THRES
=
0.7
...
@@ -52,6 +52,7 @@ RPN_MIN_SIZE = 0
...
@@ -52,6 +52,7 @@ RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH
=
0.7
RPN_PROPOSAL_NMS_THRESH
=
0.7
TRAIN_PRE_NMS_TOPK
=
12000
TRAIN_PRE_NMS_TOPK
=
12000
TRAIN_POST_NMS_TOPK
=
2000
TRAIN_POST_NMS_TOPK
=
2000
TRAIN_FPN_NMS_TOPK
=
2000
# boxes overlapping crowd will be ignored.
# boxes overlapping crowd will be ignored.
CROWD_OVERLAP_THRES
=
0.7
CROWD_OVERLAP_THRES
=
0.7
...
@@ -62,19 +63,16 @@ FASTRCNN_FG_THRESH = 0.5
...
@@ -62,19 +63,16 @@ FASTRCNN_FG_THRESH = 0.5
# fg ratio in a ROI batch
# fg ratio in a ROI batch
FASTRCNN_FG_RATIO
=
0.25
FASTRCNN_FG_RATIO
=
0.25
# modeling -------------------------
FPN_NUM_CHANNEL
=
256
FASTRCNN_FC_HEAD_DIM
=
1024
MASKRCNN_HEAD_DIM
=
256
# testing -----------------------
# testing -----------------------
TEST_PRE_NMS_TOPK
=
6000
TEST_PRE_NMS_TOPK
=
6000
TEST_POST_NMS_TOPK
=
1000
# if you encounter OOM in inference, set this to a smaller number
TEST_POST_NMS_TOPK
=
1000
# if you encounter OOM in inference, set this to a smaller number
TEST_FPN_NMS_TOPK
=
1000
FASTRCNN_NMS_THRESH
=
0.5
FASTRCNN_NMS_THRESH
=
0.5
RESULT_SCORE_THRESH
=
0.05
RESULT_SCORE_THRESH
=
0.05
RESULT_SCORE_THRESH_VIS
=
0.3
# only visualize confident results
RESULT_SCORE_THRESH_VIS
=
0.3
# only visualize confident results
RESULTS_PER_IM
=
100
RESULTS_PER_IM
=
100
# TODO Not Functioning. Don't USE
MODE_FPN
=
True
FPN_NUM_CHANNEL
=
256
MASKRCNN_HEAD_DIM
=
256
FASTRCNN_FC_HEAD_DIM
=
1024
FPN_RESOLUTION_REQUIREMENT
=
32
TRAIN_FPN_NMS_TOPK
=
2000
TEST_FPN_NMS_TOPK
=
1000
examples/FasterRCNN/data.py
View file @
e233d835
...
@@ -8,7 +8,7 @@ import itertools
...
@@ -8,7 +8,7 @@ import itertools
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.dataflow
import
(
from
tensorpack.dataflow
import
(
imgaug
,
TestDataSpeed
,
PrefetchDataZMQ
,
M
apData
,
M
ultiProcessMapDataZMQ
,
imgaug
,
TestDataSpeed
,
PrefetchDataZMQ
,
MultiProcessMapDataZMQ
,
MapDataComponent
,
DataFromList
)
MapDataComponent
,
DataFromList
)
# import tensorpack.utils.viz as tpviz
# import tensorpack.utils.viz as tpviz
...
...
examples/FasterRCNN/model.py
View file @
e233d835
...
@@ -12,7 +12,6 @@ from tensorpack.models import (
...
@@ -12,7 +12,6 @@ from tensorpack.models import (
Conv2D
,
FullyConnected
,
MaxPooling
,
Conv2D
,
FullyConnected
,
MaxPooling
,
layer_register
,
Conv2DTranspose
,
FixedUnPooling
)
layer_register
,
Conv2DTranspose
,
FixedUnPooling
)
from
tensorpack.utils
import
logger
from
utils.box_ops
import
pairwise_iou
from
utils.box_ops
import
pairwise_iou
from
utils.box_ops
import
area
as
tf_area
from
utils.box_ops
import
area
as
tf_area
import
config
import
config
...
@@ -90,7 +89,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
...
@@ -90,7 +89,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
valid_label_prob
>
th
,
valid_label_prob
>
th
,
tf
.
equal
(
valid_prediction
,
valid_anchor_labels
)),
tf
.
equal
(
valid_prediction
,
valid_anchor_labels
)),
dtype
=
tf
.
int32
)
dtype
=
tf
.
int32
)
placeholder
=
0.5
#
TODO
A small value will make summaries appear lower.
placeholder
=
0.5
# A small value will make summaries appear lower.
recall
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos
))
recall
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos
))
recall
=
tf
.
where
(
tf
.
equal
(
nr_pos
,
0
),
placeholder
,
recall
,
name
=
'recall_th{}'
.
format
(
th
))
recall
=
tf
.
where
(
tf
.
equal
(
nr_pos
,
0
),
placeholder
,
recall
,
name
=
'recall_th{}'
.
format
(
th
))
precision
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos_prediction
))
precision
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos_prediction
))
...
@@ -99,7 +98,9 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
...
@@ -99,7 +98,9 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
summaries
.
extend
([
precision
,
recall
])
summaries
.
extend
([
precision
,
recall
])
add_moving_summary
(
*
summaries
)
add_moving_summary
(
*
summaries
)
placeholder
=
0.
# Per-level loss summaries in FPN may appear lower. But the sum should be OK.
# Per-level loss summaries in FPN may appear lower due to the use of a small placeholder.
# But the total loss is still the same.
placeholder
=
0.
label_loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
label_loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
tf
.
to_float
(
valid_anchor_labels
),
logits
=
valid_label_logits
)
labels
=
tf
.
to_float
(
valid_anchor_labels
),
logits
=
valid_label_logits
)
label_loss
=
tf
.
reduce_sum
(
label_loss
)
*
(
1.
/
config
.
RPN_BATCH_PER_IM
)
label_loss
=
tf
.
reduce_sum
(
label_loss
)
*
(
1.
/
config
.
RPN_BATCH_PER_IM
)
...
@@ -601,19 +602,18 @@ def fpn_model(features):
...
@@ -601,19 +602,18 @@ def fpn_model(features):
num_channel
=
config
.
FPN_NUM_CHANNEL
num_channel
=
config
.
FPN_NUM_CHANNEL
def
upsample2x
(
name
,
x
):
def
upsample2x
(
name
,
x
):
# TODO may not be optimal in speed or math
logger
.
info
(
"Unpool 1111 ..."
)
return
FixedUnPooling
(
return
FixedUnPooling
(
name
,
x
,
2
,
unpool_mat
=
np
.
ones
((
2
,
2
),
dtype
=
'float32'
),
name
,
x
,
2
,
unpool_mat
=
np
.
ones
((
2
,
2
),
dtype
=
'float32'
),
data_format
=
'channels_first'
)
data_format
=
'channels_first'
)
with
tf
.
name_scope
(
name
):
# tf.image.resize is, again, not aligned.
logger
.
info
(
"Nearest neighbor"
)
# with tf.name_scope(name):
shape2d
=
tf
.
shape
(
x
)[
2
:]
# logger.info("Nearest neighbor")
x
=
tf
.
transpose
(
x
,
[
0
,
2
,
3
,
1
])
# shape2d = tf.shape(x)[2:]
x
=
tf
.
image
.
resize_nearest_neighbor
(
x
,
shape2d
*
2
,
align_corners
=
True
)
# x = tf.transpose(x, [0, 2, 3, 1])
x
=
tf
.
transpose
(
x
,
[
0
,
3
,
1
,
2
])
# x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
return
x
# x = tf.transpose(x, [0, 3, 1, 2])
# return x
with
argscope
(
Conv2D
,
data_format
=
'channels_first'
,
with
argscope
(
Conv2D
,
data_format
=
'channels_first'
,
nl
=
tf
.
identity
,
use_bias
=
True
,
nl
=
tf
.
identity
,
use_bias
=
True
,
...
@@ -636,6 +636,8 @@ def fpn_model(features):
...
@@ -636,6 +636,8 @@ def fpn_model(features):
@
under_name_scope
()
@
under_name_scope
()
def
fpn_map_rois_to_levels
(
boxes
):
def
fpn_map_rois_to_levels
(
boxes
):
"""
"""
Assign boxes to level 2~5.
Args:
Args:
boxes (nx4)
boxes (nx4)
...
...
examples/FasterRCNN/train.py
View file @
e233d835
...
@@ -19,6 +19,7 @@ from tensorpack import *
...
@@ -19,6 +19,7 @@ from tensorpack import *
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils.common
import
get_tf_version_number
import
tensorpack.utils.viz
as
tpviz
import
tensorpack.utils.viz
as
tpviz
from
tensorpack.utils.gpu
import
get_nr_gpu
from
tensorpack.utils.gpu
import
get_nr_gpu
...
@@ -33,8 +34,7 @@ from model import (
...
@@ -33,8 +34,7 @@ from model import (
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
fastrcnn_outputs
,
fastrcnn_losses
,
fastrcnn_predictions
,
fastrcnn_outputs
,
fastrcnn_losses
,
fastrcnn_predictions
,
maskrcnn_upXconv_head
,
maskrcnn_loss
,
maskrcnn_upXconv_head
,
maskrcnn_loss
,
fpn_model
,
fpn_map_rois_to_levels
,
fastrcnn_2fc_head
,
fpn_model
,
fastrcnn_2fc_head
,
multilevel_roi_align
)
multilevel_roi_align
)
from
data
import
(
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
,
get_all_anchors_fpn
)
get_all_anchors
,
get_all_anchors_fpn
)
...
@@ -62,6 +62,8 @@ def get_model_output_names():
...
@@ -62,6 +62,8 @@ def get_model_output_names():
def
get_model
():
def
get_model
():
if
config
.
MODE_FPN
:
if
config
.
MODE_FPN
:
if
get_tf_version
()
<
1.6
:
logger
.
warn
(
"FPN has chances to crash in TF<1.6, due to a TF issue."
)
return
ResNetFPNModel
()
return
ResNetFPNModel
()
else
:
else
:
return
ResNetC4Model
()
return
ResNetC4Model
()
...
@@ -223,8 +225,12 @@ class ResNetC4Model(DetectionModel):
...
@@ -223,8 +225,12 @@ class ResNetC4Model(DetectionModel):
ncls
=
config
.
NUM_CLASS
ncls
=
config
.
NUM_CLASS
return
tf
.
zeros
([
0
,
2048
,
7
,
7
]),
tf
.
zeros
([
0
,
ncls
]),
tf
.
zeros
([
0
,
ncls
-
1
,
4
])
return
tf
.
zeros
([
0
,
2048
,
7
,
7
]),
tf
.
zeros
([
0
,
ncls
]),
tf
.
zeros
([
0
,
ncls
-
1
,
4
])
feature_fastrcnn
,
fastrcnn_label_logits
,
fastrcnn_box_logits
=
tf
.
cond
(
if
get_tf_version_number
()
>=
1.6
:
tf
.
size
(
boxes_on_featuremap
)
>
0
,
ff_true
,
ff_false
)
feature_fastrcnn
,
fastrcnn_label_logits
,
fastrcnn_box_logits
=
ff_true
()
else
:
logger
.
warn
(
"This example may drop support for TF < 1.6 soon."
)
feature_fastrcnn
,
fastrcnn_label_logits
,
fastrcnn_box_logits
=
tf
.
cond
(
tf
.
size
(
boxes_on_featuremap
)
>
0
,
ff_true
,
ff_false
)
if
is_training
:
if
is_training
:
# rpn loss
# rpn loss
...
@@ -434,10 +440,11 @@ class ResNetFPNModel(DetectionModel):
...
@@ -434,10 +440,11 @@ class ResNetFPNModel(DetectionModel):
'maskrcnn'
,
roi_feature_maskrcnn
,
config
.
NUM_CLASS
,
4
)
# #fg x #cat x 28 x 28
'maskrcnn'
,
roi_feature_maskrcnn
,
config
.
NUM_CLASS
,
4
)
# #fg x #cat x 28 x 28
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx28x28
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx28x28
final_masks
=
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
def
visualize
(
model_path
,
nr_visualize
=
50
,
output_dir
=
'output'
):
def
visualize
(
model_path
,
nr_visualize
=
50
,
output_dir
=
'output'
):
assert
not
config
.
MODE_FPN
,
"FPN visualize is not supported yet!"
df
=
get_train_dataflow
()
# we don't visualize mask stuff
df
=
get_train_dataflow
()
# we don't visualize mask stuff
df
.
reset_state
()
df
.
reset_state
()
...
@@ -577,7 +584,7 @@ if __name__ == '__main__':
...
@@ -577,7 +584,7 @@ if __name__ == '__main__':
COCODetection
(
config
.
BASEDIR
,
'val2014'
)
# Only to load the class names into caches
COCODetection
(
config
.
BASEDIR
,
'val2014'
)
# Only to load the class names into caches
predict
(
pred
,
args
.
predict
)
predict
(
pred
,
args
.
predict
)
else
:
else
:
logger
.
set_logger_dir
(
args
.
logdir
,
'd'
)
logger
.
set_logger_dir
(
args
.
logdir
)
print_config
()
print_config
()
factor
=
get_batch_factor
()
factor
=
get_batch_factor
()
stepnum
=
config
.
STEPS_PER_EPOCH
stepnum
=
config
.
STEPS_PER_EPOCH
...
@@ -611,5 +618,5 @@ if __name__ == '__main__':
...
@@ -611,5 +618,5 @@ if __name__ == '__main__':
max_epoch
=
config
.
LR_SCHEDULE
[
-
1
]
*
factor
//
stepnum
,
max_epoch
=
config
.
LR_SCHEDULE
[
-
1
]
*
factor
//
stepnum
,
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
)
)
trainer
=
SyncMultiGPUTrainerReplicated
(
get_nr_gpu
())
trainer
=
SyncMultiGPUTrainerReplicated
(
get_nr_gpu
()
,
mode
=
'cpu'
)
launch_train_with_config
(
cfg
,
trainer
)
launch_train_with_config
(
cfg
,
trainer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment