Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c712e8dd
Commit
c712e8dd
authored
Jul 12, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] support visualization for FPN
parent
37530d96
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
14 deletions
+21
-14
examples/FasterRCNN/model_fpn.py
examples/FasterRCNN/model_fpn.py
+6
-3
examples/FasterRCNN/model_rpn.py
examples/FasterRCNN/model_rpn.py
+4
-4
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+10
-7
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+1
-0
No files found.
examples/FasterRCNN/model_fpn.py
View file @
c712e8dd
...
...
@@ -162,6 +162,7 @@ def multilevel_rpn_losses(
return
total_label_loss
,
total_box_loss
@
under_name_scope
()
def
generate_fpn_proposals
(
multilevel_anchors
,
multilevel_label_logits
,
multilevel_box_logits
,
image_shape2d
):
...
...
@@ -186,7 +187,7 @@ def generate_fpn_proposals(
if
cfg
.
FPN
.
PROPOSAL_MODE
==
'Level'
:
fpn_nms_topk
=
cfg
.
RPN
.
TRAIN_PER_LEVEL_NMS_TOPK
if
ctx
.
is_training
else
cfg
.
RPN
.
TEST_PER_LEVEL_NMS_TOPK
for
lvl
in
range
(
num_lvl
):
with
tf
.
name_scope
(
'
FPNProposal_
Lvl{}'
.
format
(
lvl
+
2
)):
with
tf
.
name_scope
(
'Lvl{}'
.
format
(
lvl
+
2
)):
anchors
=
multilevel_anchors
[
lvl
]
pred_boxes_decoded
=
anchors
.
decode_logits
(
multilevel_box_logits
[
lvl
])
...
...
@@ -204,7 +205,7 @@ def generate_fpn_proposals(
proposal_boxes
=
tf
.
gather
(
proposal_boxes
,
topk_indices
)
else
:
for
lvl
in
range
(
num_lvl
):
with
tf
.
name_scope
(
'
FPNProposal_
Lvl{}'
.
format
(
lvl
+
2
)):
with
tf
.
name_scope
(
'Lvl{}'
.
format
(
lvl
+
2
)):
anchors
=
multilevel_anchors
[
lvl
]
pred_boxes_decoded
=
anchors
.
decode_logits
(
multilevel_box_logits
[
lvl
])
all_boxes
.
append
(
tf
.
reshape
(
pred_boxes_decoded
,
[
-
1
,
4
]))
...
...
@@ -216,4 +217,6 @@ def generate_fpn_proposals(
cfg
.
RPN
.
TRAIN_PRE_NMS_TOPK
if
ctx
.
is_training
else
cfg
.
RPN
.
TEST_PRE_NMS_TOPK
,
cfg
.
RPN
.
TRAIN_POST_NMS_TOPK
if
ctx
.
is_training
else
cfg
.
RPN
.
TEST_POST_NMS_TOPK
)
return
proposal_boxes
,
proposal_scores
tf
.
sigmoid
(
proposal_scores
,
name
=
'probs'
)
# for visualization
return
tf
.
identity
(
proposal_boxes
,
name
=
'boxes'
),
\
tf
.
identity
(
proposal_scores
,
name
=
'scores'
)
examples/FasterRCNN/model_rpn.py
View file @
c712e8dd
...
...
@@ -148,7 +148,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
iou_threshold
=
cfg
.
RPN
.
PROPOSAL_NMS_THRESH
)
topk_valid_boxes
=
tf
.
reshape
(
topk_valid_boxes_x1y1x2y2
,
(
-
1
,
4
))
fin
al_boxes
=
tf
.
gather
(
topk_valid_boxes
,
nms_indices
)
fin
al_scores
=
tf
.
gather
(
topk_valid_scores
,
nms_indices
)
tf
.
sigmoid
(
fin
al_scores
,
name
=
'probs'
)
# for visualization
return
tf
.
stop_gradient
(
final_boxes
,
name
=
'boxes'
),
tf
.
stop_gradient
(
fin
al_scores
,
name
=
'scores'
)
propos
al_boxes
=
tf
.
gather
(
topk_valid_boxes
,
nms_indices
)
propos
al_scores
=
tf
.
gather
(
topk_valid_scores
,
nms_indices
)
tf
.
sigmoid
(
propos
al_scores
,
name
=
'probs'
)
# for visualization
return
tf
.
stop_gradient
(
proposal_boxes
,
name
=
'boxes'
),
tf
.
stop_gradient
(
propos
al_scores
,
name
=
'scores'
)
examples/FasterRCNN/train.py
View file @
c712e8dd
...
...
@@ -396,7 +396,7 @@ class ResNetFPNModel(DetectionModel):
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
def
visualize
(
model
_path
,
nr_visualize
=
5
0
,
output_dir
=
'output'
):
def
visualize
(
model
,
model_path
,
nr_visualize
=
10
0
,
output_dir
=
'output'
):
"""
Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
Does not support FPN.
...
...
@@ -405,12 +405,12 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
df
.
reset_state
()
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
ResNetC4Model
()
,
model
=
model
,
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
,
'gt_boxes'
,
'gt_labels'
],
output_names
=
[
'generate_
rpn_proposals/boxes'
,
'generate_
rpn_proposals/probs'
,
'generate_
{}_proposals/boxes'
.
format
(
'fpn'
if
cfg
.
MODE_FPN
else
'rpn'
)
,
'generate_
{}_proposals/probs'
.
format
(
'fpn'
if
cfg
.
MODE_FPN
else
'rpn'
)
,
'fastrcnn_all_probs'
,
'final_boxes'
,
'final_probs'
,
...
...
@@ -422,7 +422,11 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
utils
.
fs
.
mkdir_p
(
output_dir
)
with
tqdm
.
tqdm
(
total
=
nr_visualize
)
as
pbar
:
for
idx
,
dp
in
itertools
.
islice
(
enumerate
(
df
.
get_data
()),
nr_visualize
):
img
,
_
,
_
,
gt_boxes
,
gt_labels
=
dp
img
=
dp
[
0
]
if
cfg
.
MODE_MASK
:
gt_boxes
,
gt_labels
,
gt_masks
=
dp
[
-
3
:]
else
:
gt_boxes
,
gt_labels
=
dp
[
-
2
:]
rpn_boxes
,
rpn_scores
,
all_probs
,
\
final_boxes
,
final_probs
,
final_labels
=
pred
(
img
,
gt_boxes
,
gt_labels
)
...
...
@@ -530,8 +534,7 @@ if __name__ == '__main__':
cfg
.
TEST
.
RESULT_SCORE_THRESH
=
cfg
.
TEST
.
RESULT_SCORE_THRESH_VIS
if
args
.
visualize
:
assert
not
cfg
.
MODE_FPN
,
"FPN visualize is not supported!"
visualize
(
args
.
load
)
visualize
(
MODEL
,
args
.
load
)
else
:
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
MODEL
,
...
...
tensorpack/callbacks/saver.py
View file @
c712e8dd
...
...
@@ -26,6 +26,7 @@ class ModelSaver(Callback):
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
keep_checkpoint_every_n_hours (float): the same as in ``tf.train.Saver``.
Note that "keep" does not mean "create", but means "don't delete".
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save.
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment