Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e3f463ab
Commit
e3f463ab
authored
May 18, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
FPN+mask
parent
4f13f971
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
105 additions
and
51 deletions
+105
-51
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+3
-2
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+0
-1
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+5
-2
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+59
-23
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+38
-23
No files found.
examples/FasterRCNN/config.py
View file @
e3f463ab
...
...
@@ -4,7 +4,7 @@
import
numpy
as
np
# mode flags ---------------------
MODE_MASK
=
Fals
e
MODE_MASK
=
Tru
e
# dataset -----------------------
BASEDIR
=
'/path/to/your/COCO/DIR'
...
...
@@ -25,7 +25,7 @@ WARMUP = 1000 # in steps
STEPS_PER_EPOCH
=
500
LR_SCHEDULE
=
[
150000
,
230000
,
280000
]
LR_SCHEDULE
=
[
120000
,
160000
,
180000
]
# "1x" schedule in detectron
LR_SCHEDULE
=
[
240000
,
320000
,
360000
]
# "2x" schedule in detectron
#
LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
SHORT_EDGE_SIZE
=
800
...
...
@@ -73,6 +73,7 @@ RESULTS_PER_IM = 100
# TODO Not Functioning. Don't USE
MODE_FPN
=
True
FPN_NUM_CHANNEL
=
256
MASKRCNN_HEAD_DIM
=
256
FASTRCNN_FC_HEAD_DIM
=
1024
FPN_RESOLUTION_REQUIREMENT
=
32
TRAIN_FPN_NMS_TOPK
=
2000
...
...
examples/FasterRCNN/data.py
View file @
e3f463ab
...
...
@@ -344,7 +344,6 @@ def get_train_dataflow(add_mask=False):
return
ret
ds
=
MultiProcessMapDataZMQ
(
ds
,
10
,
preprocess
)
#ds = PrefetchDataZMQ(ds, 3)
return
ds
...
...
examples/FasterRCNN/eval.py
View file @
e3f463ab
...
...
@@ -141,12 +141,15 @@ def print_evaluation_scores(json_file):
cocoEval
.
evaluate
()
cocoEval
.
accumulate
()
cocoEval
.
summarize
()
ret
[
'mAP(bbox)'
]
=
cocoEval
.
stats
[
0
]
fields
=
[
'IoU=0.5:0.95'
,
'IoU=0.5'
,
'IoU=0.75'
,
'small'
,
'medium'
,
'large'
]
for
k
in
range
(
6
):
ret
[
'mAP(bbox)/'
+
fields
[
k
]]
=
cocoEval
.
stats
[
k
]
if
config
.
MODE_MASK
:
cocoEval
=
COCOeval
(
coco
,
cocoDt
,
'segm'
)
cocoEval
.
evaluate
()
cocoEval
.
accumulate
()
cocoEval
.
summarize
()
ret
[
'mAP(segm)'
]
=
cocoEval
.
stats
[
0
]
for
k
in
range
(
6
):
ret
[
'mAP(segm)/'
+
fields
[
k
]]
=
cocoEval
.
stats
[
k
]
return
ret
examples/FasterRCNN/model.py
View file @
e3f463ab
...
...
@@ -3,6 +3,8 @@
import
tensorflow
as
tf
import
numpy
as
np
import
itertools
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
,
auto_reuse_variable_scope
...
...
@@ -371,22 +373,22 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
@under_name_scope()
def roi_align(featuremap, boxes, resolution):
    """
    Crop one fixed-size feature patch per box from a single feature map.

    Every box is first cropped at twice the requested resolution with
    ``crop_and_resize`` and then reduced by a 2x2, stride-2 average pool,
    so each output bin is the mean of four sampled locations.

    Args:
        featuremap: 1xCxHxW
        boxes: Nx4 floatbox
        resolution: output spatial resolution

    Returns:
        NxCx res x res
    """
    # Box coordinates are treated as constants: no gradient flows into them.
    boxes = tf.stop_gradient(boxes)    # TODO
    # Every box indexes batch element 0 of the feature map.
    box_ind = tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32)
    # sample 4 locations per roi bin
    ret = crop_and_resize(featuremap, boxes, box_ind, resolution * 2)
    ret = tf.nn.avg_pool(
        ret, [1, 1, 2, 2], [1, 1, 2, 2],
        padding='SAME', data_format='NCHW')
    return ret
...
...
@@ -411,6 +413,25 @@ def fastrcnn_outputs(feature, num_classes):
return
classification
,
box_regression
@layer_register(log_shape=True)
def fastrcnn_2fc_head(feature, num_classes):
    """
    Two fully-connected ReLU layers ('fc6', 'fc7') of width
    ``config.FASTRCNN_FC_HEAD_DIM``, followed by ``fastrcnn_outputs``.

    Args:
        feature (any shape):
        num_classes(int): num_category + 1

    Returns:
        cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
    """
    dim = config.FASTRCNN_FC_HEAD_DIM
    logger.info("fc-head-xavier-fanin")
    # Default-argument variance-scaling initializer, shared by both layers.
    # (A tf.random_normal_initializer(stddev=0.01) alternative was left
    # commented out here.)
    init = tf.variance_scaling_initializer()

    hidden = feature
    for layer_name in ('fc6', 'fc7'):
        hidden = FullyConnected(
            layer_name, hidden, dim,
            kernel_initializer=init, nl=tf.nn.relu)
    return fastrcnn_outputs('outputs', hidden, num_classes)
@
under_name_scope
()
def
fastrcnn_losses
(
labels
,
label_logits
,
fg_boxes
,
fg_box_logits
):
"""
...
...
@@ -508,20 +529,24 @@ def fastrcnn_predictions(boxes, probs):
@layer_register(log_shape=True)
def maskrcnn_upXconv_head(feature, num_class, num_convs):
    """
    Mask head: ``num_convs`` convolutions (kernel size 3, ReLU), one
    stride-2 transposed convolution (ReLU), then a kernel-size-1
    convolution producing per-category mask logits.

    Args:
        feature (NxCx s x s): size is 7 in C4 models and 14 in FPN models.
        num_class (int): num_category + 1
        num_convs (int): number of convolution layers placed before the
            transposed convolution; 0 skips them entirely.

    Returns:
        mask_logits (N x num_category x 2s x 2s):
    """
    l = feature
    with argscope([Conv2D, Conv2DTranspose], data_format='channels_first',
                  kernel_initializer=tf.variance_scaling_initializer(
                      scale=2.0, mode='fan_out', distribution='normal')):
        # c2's MSRAFill is fan_out
        for k in range(num_convs):
            l = Conv2D('fcn{}'.format(k), l, config.MASKRCNN_HEAD_DIM, 3,
                       activation=tf.nn.relu)
        # Applied once, after the optional conv stack, with strides=2.
        l = Conv2DTranspose('deconv', l, config.MASKRCNN_HEAD_DIM, 2,
                            strides=2, activation=tf.nn.relu)
        # num_class - 1 output channels: one per category, excluding the
        # extra class counted in num_class.
        l = Conv2D('conv', l, num_class - 1, 1)
    return l
...
...
@@ -530,13 +555,13 @@ def maskrcnn_head(feature, num_class):
def
maskrcnn_loss
(
mask_logits
,
fg_labels
,
fg_target_masks
):
"""
Args:
mask_logits: #fg x #category x
14x14
mask_logits: #fg x #category x
hxw
fg_labels: #fg, in 1~#class
fg_target_masks: #fgx
14x14
, int
fg_target_masks: #fgx
hxw
, int
"""
num_fg
=
tf
.
size
(
fg_labels
)
indices
=
tf
.
stack
([
tf
.
range
(
num_fg
),
tf
.
to_int32
(
fg_labels
)
-
1
],
axis
=
1
)
# #fgx2
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fgx
14x14
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fgx
hxw
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
# add some training visualizations to tensorboard
...
...
@@ -642,22 +667,33 @@ def fpn_map_rois_to_levels(boxes):
return
level_ids
,
level_boxes
@under_name_scope()
def multilevel_roi_align(features, rcnn_boxes, resolution):
    """
    RoI-align boxes against the FPN level each one is assigned to, then
    return the crops in the same order as ``rcnn_boxes``.

    Args:
        features ([tf.Tensor]): 4 FPN feature level 2-5
        rcnn_boxes (tf.Tensor): nx4 boxes
        resolution (int): output spatial resolution

    Returns:
        NxC x res x res
    """
    assert len(features) == 4, features
    # Reassign rcnn_boxes to levels
    level_ids, level_boxes = fpn_map_rois_to_levels(rcnn_boxes)

    # Crop patches from corresponding levels
    all_rois = []
    for i, (boxes, featuremap) in enumerate(zip(level_boxes, features)):
        with tf.name_scope('roi_level{}'.format(i + 2)):
            # Rescale boxes by this level's stride before cropping.
            boxes_on_featuremap = boxes * (1.0 / config.ANCHOR_STRIDES_FPN[i])
            all_rois.append(
                roi_align(featuremap, boxes_on_featuremap, resolution))
    all_rois = tf.concat(all_rois, axis=0)  # NCHW

    # Unshuffle to the original order, to match the original samples.
    # tf.invert_permutation requires a permutation of 0..N-1, so level_ids
    # are presumably 0-based indices into rcnn_boxes -- confirm against
    # fpn_map_rois_to_levels.
    level_id_perm = tf.concat(level_ids, axis=0)
    level_id_invert_perm = tf.invert_permutation(level_id_perm)
    all_rois = tf.gather(all_rois, level_id_invert_perm)
    return all_rois
if
__name__
==
'__main__'
:
...
...
examples/FasterRCNN/train.py
View file @
e3f463ab
...
...
@@ -32,8 +32,9 @@ from model import (
rpn_head
,
rpn_losses
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
fastrcnn_outputs
,
fastrcnn_losses
,
fastrcnn_predictions
,
maskrcnn_head
,
maskrcnn_loss
,
fpn_model
,
fpn_map_rois_to_levels
,
fastrcnn_2fc_head
)
maskrcnn_upXconv_head
,
maskrcnn_loss
,
fpn_model
,
fpn_map_rois_to_levels
,
fastrcnn_2fc_head
,
multilevel_roi_align
)
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
,
get_all_anchors_fpn
)
...
...
@@ -245,11 +246,12 @@ class ResNetC4Model(DetectionModel):
fg_labels
=
tf
.
gather
(
rcnn_labels
,
fg_inds_wrt_sample
)
# In training, mask branch shares the same C5 feature.
fg_feature
=
tf
.
gather
(
feature_fastrcnn
,
fg_inds_wrt_sample
)
mask_logits
=
maskrcnn_head
(
'maskrcnn'
,
fg_feature
,
config
.
NUM_CLASS
)
# #fg x #cat x 14x14
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
fg_feature
,
config
.
NUM_CLASS
,
0
)
# #fg x #cat x 14x14
gt_masks_for_fg
=
tf
.
gather
(
gt_masks
,
fg_inds_wrt_gt
)
# nfg x H x W
matched_gt_masks
=
tf
.
gather
(
gt_masks
,
fg_inds_wrt_gt
)
# nfg x H x W
target_masks_for_fg
=
crop_and_resize
(
tf
.
expand_dims
(
gt_masks_for_fg
,
1
),
tf
.
expand_dims
(
matched_gt_masks
,
1
),
fg_sampled_boxes
,
tf
.
range
(
tf
.
size
(
fg_inds_wrt_gt
)),
14
,
pad_border
=
False
)
# nfg x 1x14x14
...
...
@@ -279,8 +281,8 @@ class ResNetC4Model(DetectionModel):
def
f1
():
roi_resized
=
roi_align
(
featuremap
,
final_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
),
14
)
feature_maskrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
mask_logits
=
maskrcnn_head
(
'maskrcnn'
,
feature_maskrcnn
,
config
.
NUM_CLASS
)
# #result x #cat x 14x14
mask_logits
=
maskrcnn_
upXconv_
head
(
'maskrcnn'
,
feature_maskrcnn
,
config
.
NUM_CLASS
,
0
)
# #result x #cat x 14x14
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx14x14
return
tf
.
sigmoid
(
final_mask_logits
)
...
...
@@ -370,25 +372,13 @@ class ResNetFPNModel(DetectionModel):
# The boxes to be used to crop RoIs.
rcnn_boxes
=
proposal_boxes
# Reassign rcnn_boxes to levels
level_ids
,
level_boxes
=
fpn_map_rois_to_levels
(
rcnn_boxes
)
all_rois
=
[]
# Crop patches from corresponding levels
for
i
,
boxes
,
featuremap
in
zip
(
itertools
.
count
(),
level_boxes
,
p23456
[:
4
]):
with
tf
.
name_scope
(
'roi_level{}'
.
format
(
i
+
2
)):
boxes_on_featuremap
=
boxes
*
(
1.0
/
config
.
ANCHOR_STRIDES_FPN
[
i
])
all_rois
.
append
(
roi_align
(
featuremap
,
boxes_on_featuremap
,
7
))
all_rois
=
tf
.
concat
(
all_rois
,
axis
=
0
)
# NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm
=
tf
.
concat
(
level_ids
,
axis
=
0
)
# A permutation of 1~N
level_id_invert_perm
=
tf
.
invert_permutation
(
level_id_perm
)
all_rois
=
tf
.
gather
(
all_rois
,
level_id_invert_perm
)
roi_feature_fastrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
rcnn_boxes
,
7
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_2fc_head
(
'fastrcnn'
,
all_rois
,
config
.
FASTRCNN_FC_HEAD_DIM
,
config
.
NUM_CLASS
)
'fastrcnn'
,
roi_feature_fastrcnn
,
config
.
NUM_CLASS
)
if
is_training
:
# rpn loss ...
with
tf
.
name_scope
(
'rpn_losses'
):
rpn_total_label_loss
=
tf
.
add_n
(
rpn_loss_collection
[::
2
],
name
=
'label_loss'
)
rpn_total_box_loss
=
tf
.
add_n
(
rpn_loss_collection
[
1
::
2
],
name
=
'box_loss'
)
...
...
@@ -405,7 +395,24 @@ class ResNetFPNModel(DetectionModel):
image
,
rcnn_labels
,
fg_sampled_boxes
,
matched_gt_boxes
,
fastrcnn_label_logits
,
fg_fastrcnn_box_logits
)
mrcnn_loss
=
0.0
if
config
.
MODE_MASK
:
# maskrcnn loss
fg_labels
=
tf
.
gather
(
rcnn_labels
,
fg_inds_wrt_sample
)
roi_feature_maskrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
fg_sampled_boxes
,
14
)
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
roi_feature_maskrcnn
,
config
.
NUM_CLASS
,
4
)
# #fg x #cat x 28 x 28
matched_gt_masks
=
tf
.
gather
(
gt_masks
,
fg_inds_wrt_gt
)
# fg x H x W
target_masks_for_fg
=
crop_and_resize
(
tf
.
expand_dims
(
matched_gt_masks
,
1
),
fg_sampled_boxes
,
tf
.
range
(
tf
.
size
(
fg_inds_wrt_gt
)),
28
,
pad_border
=
False
)
# fg x 1x28x28
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
mrcnn_loss
=
maskrcnn_loss
(
mask_logits
,
fg_labels
,
target_masks_for_fg
)
else
:
mrcnn_loss
=
0.0
wd_cost
=
regularize_cost
(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W'
,
...
...
@@ -420,6 +427,14 @@ class ResNetFPNModel(DetectionModel):
else
:
final_boxes
,
final_labels
=
self
.
fastrcnn_inference
(
image_shape2d
,
rcnn_boxes
,
fastrcnn_label_logits
,
fastrcnn_box_logits
)
if
config
.
MODE_MASK
:
roi_feature_maskrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
final_boxes
,
14
)
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
roi_feature_maskrcnn
,
config
.
NUM_CLASS
,
4
)
# #fg x #cat x 28 x 28
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx28x28
final_masks
=
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
def
visualize
(
model_path
,
nr_visualize
=
50
,
output_dir
=
'output'
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment