Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
08b0dfb6
Commit
08b0dfb6
authored
Jul 01, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] split models into separate files
parent
70b70736
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
690 additions
and
9 deletions
+690
-9
examples/FasterRCNN/NOTES.md
examples/FasterRCNN/NOTES.md
+1
-1
examples/FasterRCNN/model_fpn.py
examples/FasterRCNN/model_fpn.py
+194
-0
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+255
-0
examples/FasterRCNN/model_mrcnn.py
examples/FasterRCNN/model_mrcnn.py
+73
-0
examples/FasterRCNN/model_rpn.py
examples/FasterRCNN/model_rpn.py
+155
-0
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+12
-8
No files found.
examples/FasterRCNN/NOTES.md
View file @
08b0dfb6
...
@@ -6,7 +6,7 @@ This is a minimal implementation that simply contains these files:
...
@@ -6,7 +6,7 @@ This is a minimal implementation that simply contains these files:
+
common.py: common data preparation utilities
+
common.py: common data preparation utilities
+
basemodel.py: implement backbones
+
basemodel.py: implement backbones
+
model_box.py: implement box-related symbolic functions
+
model_box.py: implement box-related symbolic functions
+
model
.py: implement RPN/Faster-RCNN/FPN/Mask-RCNN
+
model
_{fpn,rpn,mrcnn,frcnn}.py: implement FPN,RPN,Mask-/Fast-RCNN models.
+
train.py: main training script
+
train.py: main training script
+
utils/: third-party helper functions
+
utils/: third-party helper functions
+
eval.py: evaluation utilities
+
eval.py: evaluation utilities
...
...
examples/FasterRCNN/model_fpn.py
0 → 100644
View file @
08b0dfb6
# -*- coding: utf-8 -*-
import
numpy
as
np
import
tensorflow
as
tf
import
itertools
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
(
Conv2D
,
layer_register
,
FixedUnPooling
,
MaxPooling
)
from
model_rpn
import
rpn_losses
,
generate_rpn_proposals
from
model_box
import
roi_align
from
utils.box_ops
import
area
as
tf_area
from
config
import
config
as
cfg
@
layer_register
(
log_shape
=
True
)
def
fpn_model
(
features
):
"""
Args:
features ([tf.Tensor]): ResNet features c2-c5
Returns:
[tf.Tensor]: FPN features p2-p6
"""
assert
len
(
features
)
==
4
,
features
num_channel
=
cfg
.
FPN
.
NUM_CHANNEL
def
upsample2x
(
name
,
x
):
return
FixedUnPooling
(
name
,
x
,
2
,
unpool_mat
=
np
.
ones
((
2
,
2
),
dtype
=
'float32'
),
data_format
=
'channels_first'
)
# tf.image.resize is, again, not aligned.
# with tf.name_scope(name):
# shape2d = tf.shape(x)[2:]
# x = tf.transpose(x, [0, 2, 3, 1])
# x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
# x = tf.transpose(x, [0, 3, 1, 2])
# return x
with
argscope
(
Conv2D
,
data_format
=
'channels_first'
,
activation
=
tf
.
identity
,
use_bias
=
True
,
kernel_initializer
=
tf
.
variance_scaling_initializer
(
scale
=
1.
)):
lat_2345
=
[
Conv2D
(
'lateral_1x1_c{}'
.
format
(
i
+
2
),
c
,
num_channel
,
1
)
for
i
,
c
in
enumerate
(
features
)]
lat_sum_5432
=
[]
for
idx
,
lat
in
enumerate
(
lat_2345
[::
-
1
]):
if
idx
==
0
:
lat_sum_5432
.
append
(
lat
)
else
:
lat
=
lat
+
upsample2x
(
'upsample_lat{}'
.
format
(
6
-
idx
),
lat_sum_5432
[
-
1
])
lat_sum_5432
.
append
(
lat
)
p2345
=
[
Conv2D
(
'posthoc_3x3_p{}'
.
format
(
i
+
2
),
c
,
num_channel
,
3
)
for
i
,
c
in
enumerate
(
lat_sum_5432
[::
-
1
])]
p6
=
MaxPooling
(
'maxpool_p6'
,
p2345
[
-
1
],
pool_size
=
1
,
strides
=
2
,
data_format
=
'channels_first'
)
return
p2345
+
[
p6
]
@
under_name_scope
()
def
fpn_map_rois_to_levels
(
boxes
):
"""
Assign boxes to level 2~5.
Args:
boxes (nx4):
Returns:
[tf.Tensor]: 4 tensors for level 2-5. Each tensor is a vector of indices of boxes in its level.
[tf.Tensor]: 4 tensors, the gathered boxes in each level.
Be careful that the returned tensor could be empty.
"""
sqrtarea
=
tf
.
sqrt
(
tf_area
(
boxes
))
level
=
tf
.
to_int32
(
tf
.
floor
(
4
+
tf
.
log
(
sqrtarea
*
(
1.
/
224
)
+
1e-6
)
*
(
1.0
/
np
.
log
(
2
))))
# RoI levels range from 2~5 (not 6)
level_ids
=
[
tf
.
where
(
level
<=
2
),
tf
.
where
(
tf
.
equal
(
level
,
3
)),
# == is not supported
tf
.
where
(
tf
.
equal
(
level
,
4
)),
tf
.
where
(
level
>=
5
)]
level_ids
=
[
tf
.
reshape
(
x
,
[
-
1
],
name
=
'roi_level{}_id'
.
format
(
i
+
2
))
for
i
,
x
in
enumerate
(
level_ids
)]
num_in_levels
=
[
tf
.
size
(
x
,
name
=
'num_roi_level{}'
.
format
(
i
+
2
))
for
i
,
x
in
enumerate
(
level_ids
)]
add_moving_summary
(
*
num_in_levels
)
level_boxes
=
[
tf
.
gather
(
boxes
,
ids
)
for
ids
in
level_ids
]
return
level_ids
,
level_boxes
@
under_name_scope
()
def
multilevel_roi_align
(
features
,
rcnn_boxes
,
resolution
):
"""
Args:
features ([tf.Tensor]): 4 FPN feature level 2-5
rcnn_boxes (tf.Tensor): nx4 boxes
resolution (int): output spatial resolution
Returns:
NxC x res x res
"""
assert
len
(
features
)
==
4
,
features
# Reassign rcnn_boxes to levels
level_ids
,
level_boxes
=
fpn_map_rois_to_levels
(
rcnn_boxes
)
all_rois
=
[]
# Crop patches from corresponding levels
for
i
,
boxes
,
featuremap
in
zip
(
itertools
.
count
(),
level_boxes
,
features
):
with
tf
.
name_scope
(
'roi_level{}'
.
format
(
i
+
2
)):
boxes_on_featuremap
=
boxes
*
(
1.0
/
cfg
.
FPN
.
ANCHOR_STRIDES
[
i
])
all_rois
.
append
(
roi_align
(
featuremap
,
boxes_on_featuremap
,
resolution
))
all_rois
=
tf
.
concat
(
all_rois
,
axis
=
0
)
# NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm
=
tf
.
concat
(
level_ids
,
axis
=
0
)
# A permutation of 1~N
level_id_invert_perm
=
tf
.
invert_permutation
(
level_id_perm
)
all_rois
=
tf
.
gather
(
all_rois
,
level_id_invert_perm
)
return
all_rois
def
multilevel_rpn_losses
(
multilevel_anchors
,
multilevel_label_logits
,
multilevel_box_logits
):
"""
Args:
multilevel_anchors: #lvl RPNAnchors
multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns:
label_loss, box_loss
"""
num_lvl
=
len
(
cfg
.
FPN
.
ANCHOR_STRIDES
)
assert
len
(
multilevel_anchors
)
==
num_lvl
assert
len
(
multilevel_label_logits
)
==
num_lvl
assert
len
(
multilevel_box_logits
)
==
num_lvl
losses
=
[]
for
lvl
in
range
(
num_lvl
):
with
tf
.
name_scope
(
'RPNLoss_Lvl{}'
.
format
(
lvl
+
2
)):
anchors
=
multilevel_anchors
[
lvl
]
label_loss
,
box_loss
=
rpn_losses
(
anchors
.
gt_labels
,
anchors
.
encoded_gt_boxes
(),
multilevel_label_logits
[
lvl
],
multilevel_box_logits
[
lvl
])
losses
.
extend
([
label_loss
,
box_loss
])
with
tf
.
name_scope
(
'rpn_losses'
):
total_label_loss
=
tf
.
add_n
(
losses
[::
2
],
name
=
'label_loss'
)
total_box_loss
=
tf
.
add_n
(
losses
[
1
::
2
],
name
=
'box_loss'
)
add_moving_summary
(
total_label_loss
,
total_box_loss
)
return
total_label_loss
,
total_box_loss
def
generate_fpn_proposals
(
multilevel_anchors
,
multilevel_label_logits
,
multilevel_box_logits
,
image_shape2d
,
pre_nms_topk
,
post_nms_topk
):
"""
Args:
multilevel_anchors: #lvl RPNAnchors
multilevel_label_logits: #lvl tensors of shape HxWxA
multilevel_box_logits: #lvl tensors of shape HxWxAx4
Returns:
boxes: kx4 float
scores: k logits
"""
num_lvl
=
len
(
cfg
.
FPN
.
ANCHOR_STRIDES
)
assert
len
(
multilevel_anchors
)
==
num_lvl
assert
len
(
multilevel_label_logits
)
==
num_lvl
assert
len
(
multilevel_box_logits
)
==
num_lvl
all_boxes
=
[]
all_scores
=
[]
for
lvl
in
range
(
num_lvl
):
with
tf
.
name_scope
(
'FPNProposal_Lvl{}'
.
format
(
lvl
+
2
)):
anchors
=
multilevel_anchors
[
lvl
]
pred_boxes_decoded
=
anchors
.
decode_logits
(
multilevel_box_logits
[
lvl
])
proposal_boxes
,
proposal_scores
=
generate_rpn_proposals
(
tf
.
reshape
(
pred_boxes_decoded
,
[
-
1
,
4
]),
tf
.
reshape
(
multilevel_label_logits
[
lvl
],
[
-
1
]),
image_shape2d
,
pre_nms_topk
)
all_boxes
.
append
(
proposal_boxes
)
all_scores
.
append
(
proposal_scores
)
proposal_boxes
=
tf
.
concat
(
all_boxes
,
axis
=
0
)
# nx4
proposal_scores
=
tf
.
concat
(
all_scores
,
axis
=
0
)
# n
proposal_topk
=
tf
.
minimum
(
tf
.
size
(
proposal_scores
),
post_nms_topk
)
proposal_scores
,
topk_indices
=
tf
.
nn
.
top_k
(
proposal_scores
,
k
=
proposal_topk
,
sorted
=
False
)
proposal_boxes
=
tf
.
gather
(
proposal_boxes
,
topk_indices
)
return
proposal_boxes
,
proposal_scores
examples/FasterRCNN/model.py
→
examples/FasterRCNN/model
_frcnn
.py
View file @
08b0dfb6
This diff is collapsed.
Click to expand it.
examples/FasterRCNN/model_mrcnn.py
0 → 100644
View file @
08b0dfb6
# -*- coding: utf-8 -*-
import
tensorflow
as
tf
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.models
import
(
Conv2D
,
layer_register
,
Conv2DTranspose
)
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
config
import
config
as
cfg
@
layer_register
(
log_shape
=
True
)
def
maskrcnn_upXconv_head
(
feature
,
num_category
,
num_convs
):
"""
Args:
feature (NxCx s x s): size is 7 in C4 models and 14 in FPN models.
num_category(int):
num_convs (int): number of convolution layers
Returns:
mask_logits (N x num_category x 2s x 2s):
"""
l
=
feature
with
argscope
([
Conv2D
,
Conv2DTranspose
],
data_format
=
'channels_first'
,
kernel_initializer
=
tf
.
variance_scaling_initializer
(
scale
=
2.0
,
mode
=
'fan_out'
,
distribution
=
'normal'
)):
# c2's MSRAFill is fan_out
for
k
in
range
(
num_convs
):
l
=
Conv2D
(
'fcn{}'
.
format
(
k
),
l
,
cfg
.
MRCNN
.
HEAD_DIM
,
3
,
activation
=
tf
.
nn
.
relu
)
l
=
Conv2DTranspose
(
'deconv'
,
l
,
cfg
.
MRCNN
.
HEAD_DIM
,
2
,
strides
=
2
,
activation
=
tf
.
nn
.
relu
)
l
=
Conv2D
(
'conv'
,
l
,
num_category
,
1
)
return
l
@
under_name_scope
()
def
maskrcnn_loss
(
mask_logits
,
fg_labels
,
fg_target_masks
):
"""
Args:
mask_logits: #fg x #category xhxw
fg_labels: #fg, in 1~#class
fg_target_masks: #fgxhxw, int
"""
num_fg
=
tf
.
size
(
fg_labels
)
indices
=
tf
.
stack
([
tf
.
range
(
num_fg
),
tf
.
to_int32
(
fg_labels
)
-
1
],
axis
=
1
)
# #fgx2
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fgxhxw
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
# add some training visualizations to tensorboard
with
tf
.
name_scope
(
'mask_viz'
):
viz
=
tf
.
concat
([
fg_target_masks
,
mask_probs
],
axis
=
1
)
viz
=
tf
.
expand_dims
(
viz
,
3
)
viz
=
tf
.
cast
(
viz
*
255
,
tf
.
uint8
,
name
=
'viz'
)
tf
.
summary
.
image
(
'mask_truth|pred'
,
viz
,
max_outputs
=
10
)
loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
fg_target_masks
,
logits
=
mask_logits
)
loss
=
tf
.
reduce_mean
(
loss
,
name
=
'maskrcnn_loss'
)
pred_label
=
mask_probs
>
0.5
truth_label
=
fg_target_masks
>
0.5
accuracy
=
tf
.
reduce_mean
(
tf
.
to_float
(
tf
.
equal
(
pred_label
,
truth_label
)),
name
=
'accuracy'
)
pos_accuracy
=
tf
.
logical_and
(
tf
.
equal
(
pred_label
,
truth_label
),
tf
.
equal
(
truth_label
,
True
))
pos_accuracy
=
tf
.
reduce_mean
(
tf
.
to_float
(
pos_accuracy
),
name
=
'pos_accuracy'
)
fg_pixel_ratio
=
tf
.
reduce_mean
(
tf
.
to_float
(
truth_label
),
name
=
'fg_pixel_ratio'
)
add_moving_summary
(
loss
,
accuracy
,
fg_pixel_ratio
,
pos_accuracy
)
return
loss
examples/FasterRCNN/model_rpn.py
0 → 100644
View file @
08b0dfb6
# -*- coding: utf-8 -*-
import
tensorflow
as
tf
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
,
auto_reuse_variable_scope
from
tensorpack.models
import
Conv2D
,
layer_register
from
model_box
import
clip_boxes
from
config
import
config
as
cfg
@
layer_register
(
log_shape
=
True
)
@
auto_reuse_variable_scope
def
rpn_head
(
featuremap
,
channel
,
num_anchors
):
"""
Returns:
label_logits: fHxfWxNA
box_logits: fHxfWxNAx4
"""
with
argscope
(
Conv2D
,
data_format
=
'channels_first'
,
kernel_initializer
=
tf
.
random_normal_initializer
(
stddev
=
0.01
)):
hidden
=
Conv2D
(
'conv0'
,
featuremap
,
channel
,
3
,
activation
=
tf
.
nn
.
relu
)
label_logits
=
Conv2D
(
'class'
,
hidden
,
num_anchors
,
1
)
box_logits
=
Conv2D
(
'box'
,
hidden
,
4
*
num_anchors
,
1
)
# 1, NA(*4), im/16, im/16 (NCHW)
label_logits
=
tf
.
transpose
(
label_logits
,
[
0
,
2
,
3
,
1
])
# 1xfHxfWxNA
label_logits
=
tf
.
squeeze
(
label_logits
,
0
)
# fHxfWxNA
shp
=
tf
.
shape
(
box_logits
)
# 1x(NAx4)xfHxfW
box_logits
=
tf
.
transpose
(
box_logits
,
[
0
,
2
,
3
,
1
])
# 1xfHxfWx(NAx4)
box_logits
=
tf
.
reshape
(
box_logits
,
tf
.
stack
([
shp
[
2
],
shp
[
3
],
num_anchors
,
4
]))
# fHxfWxNAx4
return
label_logits
,
box_logits
@
under_name_scope
()
def
rpn_losses
(
anchor_labels
,
anchor_boxes
,
label_logits
,
box_logits
):
"""
Args:
anchor_labels: fHxfWxNA
anchor_boxes: fHxfWxNAx4, encoded
label_logits: fHxfWxNA
box_logits: fHxfWxNAx4
Returns:
label_loss, box_loss
"""
with
tf
.
device
(
'/cpu:0'
):
valid_mask
=
tf
.
stop_gradient
(
tf
.
not_equal
(
anchor_labels
,
-
1
))
pos_mask
=
tf
.
stop_gradient
(
tf
.
equal
(
anchor_labels
,
1
))
nr_valid
=
tf
.
stop_gradient
(
tf
.
count_nonzero
(
valid_mask
,
dtype
=
tf
.
int32
),
name
=
'num_valid_anchor'
)
nr_pos
=
tf
.
identity
(
tf
.
count_nonzero
(
pos_mask
,
dtype
=
tf
.
int32
),
name
=
'num_pos_anchor'
)
# nr_pos is guaranteed >0 in C4. But in FPN. even nr_valid could be 0.
valid_anchor_labels
=
tf
.
boolean_mask
(
anchor_labels
,
valid_mask
)
valid_label_logits
=
tf
.
boolean_mask
(
label_logits
,
valid_mask
)
with
tf
.
name_scope
(
'label_metrics'
):
valid_label_prob
=
tf
.
nn
.
sigmoid
(
valid_label_logits
)
summaries
=
[]
with
tf
.
device
(
'/cpu:0'
):
for
th
in
[
0.5
,
0.2
,
0.1
]:
valid_prediction
=
tf
.
cast
(
valid_label_prob
>
th
,
tf
.
int32
)
nr_pos_prediction
=
tf
.
reduce_sum
(
valid_prediction
,
name
=
'num_pos_prediction'
)
pos_prediction_corr
=
tf
.
count_nonzero
(
tf
.
logical_and
(
valid_label_prob
>
th
,
tf
.
equal
(
valid_prediction
,
valid_anchor_labels
)),
dtype
=
tf
.
int32
)
placeholder
=
0.5
# A small value will make summaries appear lower.
recall
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos
))
recall
=
tf
.
where
(
tf
.
equal
(
nr_pos
,
0
),
placeholder
,
recall
,
name
=
'recall_th{}'
.
format
(
th
))
precision
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos_prediction
))
precision
=
tf
.
where
(
tf
.
equal
(
nr_pos_prediction
,
0
),
placeholder
,
precision
,
name
=
'precision_th{}'
.
format
(
th
))
summaries
.
extend
([
precision
,
recall
])
add_moving_summary
(
*
summaries
)
# Per-level loss summaries in FPN may appear lower due to the use of a small placeholder.
# But the total loss is still the same. TODO make the summary op smarter
placeholder
=
0.
label_loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
tf
.
to_float
(
valid_anchor_labels
),
logits
=
valid_label_logits
)
label_loss
=
tf
.
reduce_sum
(
label_loss
)
*
(
1.
/
cfg
.
RPN
.
BATCH_PER_IM
)
label_loss
=
tf
.
where
(
tf
.
equal
(
nr_valid
,
0
),
placeholder
,
label_loss
,
name
=
'label_loss'
)
pos_anchor_boxes
=
tf
.
boolean_mask
(
anchor_boxes
,
pos_mask
)
pos_box_logits
=
tf
.
boolean_mask
(
box_logits
,
pos_mask
)
delta
=
1.0
/
9
box_loss
=
tf
.
losses
.
huber_loss
(
pos_anchor_boxes
,
pos_box_logits
,
delta
=
delta
,
reduction
=
tf
.
losses
.
Reduction
.
SUM
)
/
delta
box_loss
=
box_loss
*
(
1.
/
cfg
.
RPN
.
BATCH_PER_IM
)
box_loss
=
tf
.
where
(
tf
.
equal
(
nr_pos
,
0
),
placeholder
,
box_loss
,
name
=
'box_loss'
)
add_moving_summary
(
label_loss
,
box_loss
,
nr_valid
,
nr_pos
)
return
label_loss
,
box_loss
@
under_name_scope
()
def
generate_rpn_proposals
(
boxes
,
scores
,
img_shape
,
pre_nms_topk
,
post_nms_topk
=
None
):
"""
Sample RPN proposals by the following steps:
1. Pick top k1 by scores
2. NMS them
3. Pick top k2 by scores. Default k2 == k1, i.e. does not filter the NMS output.
Args:
boxes: nx4 float dtype, the proposal boxes. Decoded to floatbox already
scores: n float, the logits
img_shape: [h, w]
pre_nms_topk, post_nms_topk (int): See above.
Returns:
boxes: kx4 float
scores: k logits
"""
assert
boxes
.
shape
.
ndims
==
2
,
boxes
.
shape
if
post_nms_topk
is
None
:
post_nms_topk
=
pre_nms_topk
topk
=
tf
.
minimum
(
pre_nms_topk
,
tf
.
size
(
scores
))
topk_scores
,
topk_indices
=
tf
.
nn
.
top_k
(
scores
,
k
=
topk
,
sorted
=
False
)
topk_boxes
=
tf
.
gather
(
boxes
,
topk_indices
)
topk_boxes
=
clip_boxes
(
topk_boxes
,
img_shape
)
topk_boxes_x1y1x2y2
=
tf
.
reshape
(
topk_boxes
,
(
-
1
,
2
,
2
))
topk_boxes_x1y1
,
topk_boxes_x2y2
=
tf
.
split
(
topk_boxes_x1y1x2y2
,
2
,
axis
=
1
)
# nx1x2 each
wbhb
=
tf
.
squeeze
(
topk_boxes_x2y2
-
topk_boxes_x1y1
,
axis
=
1
)
valid
=
tf
.
reduce_all
(
wbhb
>
cfg
.
RPN
.
MIN_SIZE
,
axis
=
1
)
# n,
topk_valid_boxes_x1y1x2y2
=
tf
.
boolean_mask
(
topk_boxes_x1y1x2y2
,
valid
)
topk_valid_scores
=
tf
.
boolean_mask
(
topk_scores
,
valid
)
# TODO not needed
topk_valid_boxes_y1x1y2x2
=
tf
.
reshape
(
tf
.
reverse
(
topk_valid_boxes_x1y1x2y2
,
axis
=
[
2
]),
(
-
1
,
4
),
name
=
'nms_input_boxes'
)
nms_indices
=
tf
.
image
.
non_max_suppression
(
topk_valid_boxes_y1x1y2x2
,
# TODO use exp to work around a bug in TF1.9: https://github.com/tensorflow/tensorflow/issues/19578
tf
.
exp
(
topk_valid_scores
),
max_output_size
=
post_nms_topk
,
iou_threshold
=
cfg
.
RPN
.
PROPOSAL_NMS_THRESH
)
topk_valid_boxes
=
tf
.
reshape
(
topk_valid_boxes_x1y1x2y2
,
(
-
1
,
4
))
final_boxes
=
tf
.
gather
(
topk_valid_boxes
,
nms_indices
)
final_scores
=
tf
.
gather
(
topk_valid_scores
,
nms_indices
)
tf
.
sigmoid
(
final_scores
,
name
=
'probs'
)
# for visualization
return
tf
.
stop_gradient
(
final_boxes
,
name
=
'boxes'
),
tf
.
stop_gradient
(
final_scores
,
name
=
'scores'
)
examples/FasterRCNN/train.py
View file @
08b0dfb6
...
@@ -29,16 +29,20 @@ from coco import COCODetection
...
@@ -29,16 +29,20 @@ from coco import COCODetection
from
basemodel
import
(
from
basemodel
import
(
image_preprocess
,
resnet_c4_backbone
,
resnet_conv5
,
image_preprocess
,
resnet_c4_backbone
,
resnet_conv5
,
resnet_fpn_backbone
)
resnet_fpn_backbone
)
import
model
from
model
import
(
import
model_frcnn
rpn_head
,
rpn_losses
,
from
model_frcnn
import
(
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
sample_fast_rcnn_targets
,
fastrcnn_outputs
,
fastrcnn_losses
,
fastrcnn_predictions
,
fastrcnn_outputs
,
fastrcnn_losses
,
fastrcnn_predictions
)
maskrcnn_upXconv_head
,
maskrcnn_loss
,
from
model_mrcnn
import
maskrcnn_upXconv_head
,
maskrcnn_loss
fpn_model
,
multilevel_roi_align
,
multilevel_rpn_losses
,
generate_fpn_proposals
)
from
model_rpn
import
rpn_head
,
rpn_losses
,
generate_rpn_proposals
from
model_fpn
import
(
fpn_model
,
multilevel_roi_align
,
multilevel_rpn_losses
,
generate_fpn_proposals
)
from
model_box
import
(
from
model_box
import
(
clip_boxes
,
decode_bbox_target
,
encode_bbox_target
,
clip_boxes
,
decode_bbox_target
,
encode_bbox_target
,
crop_and_resize
,
roi_align
,
RPNAnchors
)
crop_and_resize
,
roi_align
,
RPNAnchors
)
from
data
import
(
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
,
get_all_anchors_fpn
)
get_all_anchors
,
get_all_anchors_fpn
)
...
@@ -328,7 +332,7 @@ class ResNetFPNModel(DetectionModel):
...
@@ -328,7 +332,7 @@ class ResNetFPNModel(DetectionModel):
roi_feature_fastrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
rcnn_boxes
,
7
)
roi_feature_fastrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
rcnn_boxes
,
7
)
fastrcnn_head_func
=
getattr
(
model
,
cfg
.
FPN
.
FRCNN_HEAD_FUNC
)
fastrcnn_head_func
=
getattr
(
model
_frcnn
,
cfg
.
FPN
.
FRCNN_HEAD_FUNC
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head_func
(
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head_func
(
'fastrcnn'
,
roi_feature_fastrcnn
,
cfg
.
DATA
.
NUM_CLASS
)
'fastrcnn'
,
roi_feature_fastrcnn
,
cfg
.
DATA
.
NUM_CLASS
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment