Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4db43fec
Commit
4db43fec
authored
Nov 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] predict boxes inside the graph, and improve performance
parent
3a4fdf64
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
121 additions
and
117 deletions
+121
-117
examples/FasterRCNN/README.md
examples/FasterRCNN/README.md
+4
-4
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+20
-60
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+62
-23
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+27
-20
examples/FasterRCNN/viz.py
examples/FasterRCNN/viz.py
+8
-10
No files found.
examples/FasterRCNN/README.md
View file @
4db43fec
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
This example aims to provide a minimal (<1000 lines) multi-GPU implementation of ResNet50-Faster-RCNN on COCO.
This example aims to provide a minimal (<1000 lines) multi-GPU implementation of ResNet50-Faster-RCNN on COCO.
## Dependencies
## Dependencies
+
TensorFlow >
= 1.4.0rc0
+
TensorFlow >
1.4.0 (use tf-nightly-gpu for now)
+
Install
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
Install
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
Pre-trained
[
ResNet50 model
](
https://goo.gl/6XjK9V
)
from tensorpack model zoo.
+
Pre-trained
[
ResNet50 model
](
https://goo.gl/6XjK9V
)
from tensorpack model zoo.
+
COCO data. It assumes the following directory structure:
+
COCO data. It assumes the following directory structure:
...
@@ -46,10 +46,10 @@ To evaluate the performance (pretrained models can be downloaded in [model zoo](
...
@@ -46,10 +46,10 @@ To evaluate the performance (pretrained models can be downloaded in [model zoo](
Mean Average Precision @IoU=0.50:0.95:
Mean Average Precision @IoU=0.50:0.95:
+
trainval35k/minival, FASTRCNN_BATCH=256: 3
3.7
. Takes 49h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=256: 3
4.2
. Takes 49h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=64: 32.
2
. Takes 31h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=64: 32.
7
. Takes 31h on 8 TitanX.
The hyperparameters are not carefully tuned. You can probably get better performance by e.g.
training longer.
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer.
## Notes
## Notes
...
...
examples/FasterRCNN/eval.py
View file @
4db43fec
...
@@ -5,13 +5,13 @@
...
@@ -5,13 +5,13 @@
import
numpy
as
np
import
numpy
as
np
import
tqdm
import
tqdm
import
cv2
import
cv2
import
six
import
os
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
,
defaultdict
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorpack.dataflow
import
MapDataComponent
,
TestDataSpeed
from
tensorpack.dataflow
import
MapDataComponent
,
TestDataSpeed
from
tensorpack.tfutils
import
get_default_sess_config
from
tensorpack.tfutils
import
get_default_sess_config
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.utils.utils
import
get_tqdm_kwargs
from
tensorpack.utils.utils
import
get_tqdm_kwargs
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
...
@@ -26,59 +26,6 @@ DetectionResult = namedtuple(
...
@@ -26,59 +26,6 @@ DetectionResult = namedtuple(
[
'class_id'
,
'boxes'
,
'scores'
])
[
'class_id'
,
'boxes'
,
'scores'
])
@
memoized
def
get_tf_nms
(
num_output
,
thresh
):
"""
Get a NMS callable.
"""
# create a new graph for it
with
tf
.
Graph
()
.
as_default
(),
tf
.
device
(
'/cpu:0'
):
boxes
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
scores
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
])
indices
=
tf
.
image
.
non_max_suppression
(
boxes
,
scores
,
num_output
,
thresh
)
sess
=
tf
.
Session
(
config
=
get_default_sess_config
())
return
sess
.
make_callable
(
indices
,
[
boxes
,
scores
])
def
nms_fastrcnn_results
(
boxes
,
probs
):
"""
Args:
boxes: nx#catx4 floatbox in float32
probs: nx#class
Returns:
[DetectionResult]
"""
C
=
probs
.
shape
[
1
]
boxes
=
boxes
.
copy
()
nms_func
=
get_tf_nms
(
config
.
RESULTS_PER_IM
,
config
.
FASTRCNN_NMS_THRESH
)
ret
=
[]
for
klass
in
range
(
1
,
C
):
ids
=
np
.
where
(
probs
[:,
klass
]
>
config
.
RESULT_SCORE_THRESH
)[
0
]
if
ids
.
size
==
0
:
continue
probs_k
=
probs
[
ids
,
klass
]
.
flatten
()
boxes_k
=
boxes
[
ids
,
klass
-
1
,
:]
selected_ids
=
nms_func
(
boxes_k
,
probs_k
)
selected_boxes
=
boxes_k
[
selected_ids
,
:]
.
copy
()
ret
.
append
(
DetectionResult
(
klass
,
selected_boxes
,
probs_k
[
selected_ids
]))
if
len
(
ret
):
newret
=
[]
all_scores
=
np
.
hstack
([
x
.
scores
for
x
in
ret
])
if
len
(
all_scores
)
>
config
.
RESULTS_PER_IM
:
score_thresh
=
np
.
sort
(
all_scores
)[
-
config
.
RESULTS_PER_IM
]
for
klass
,
boxes
,
scores
in
ret
:
keep_ids
=
np
.
where
(
scores
>=
score_thresh
)[
0
]
if
len
(
keep_ids
):
newret
.
append
(
DetectionResult
(
klass
,
boxes
[
keep_ids
,
:],
scores
[
keep_ids
]))
ret
=
newret
return
ret
def
detect_one_image
(
img
,
model_func
):
def
detect_one_image
(
img
,
model_func
):
"""
"""
Run detection on one image, using the TF callable.
Run detection on one image, using the TF callable.
...
@@ -91,20 +38,33 @@ def detect_one_image(img, model_func):
...
@@ -91,20 +38,33 @@ def detect_one_image(img, model_func):
Returns:
Returns:
[DetectionResult]
[DetectionResult]
"""
"""
def
group_results_by_class
(
boxes
,
probs
,
labels
):
dic
=
defaultdict
(
list
)
for
box
,
prob
,
lab
in
zip
(
boxes
,
probs
,
labels
):
dic
[
lab
]
.
append
((
box
,
prob
))
def
mapf
(
lab
,
values
):
boxes
=
np
.
asarray
([
k
[
0
]
for
k
in
values
])
probs
=
np
.
asarray
([
k
[
1
]
for
k
in
values
])
return
DetectionResult
(
lab
,
boxes
,
probs
)
return
[
mapf
(
k
,
v
)
for
k
,
v
in
six
.
iteritems
(
dic
)]
resizer
=
CustomResize
(
config
.
SHORT_EDGE_SIZE
,
config
.
MAX_SIZE
)
resizer
=
CustomResize
(
config
.
SHORT_EDGE_SIZE
,
config
.
MAX_SIZE
)
resized_img
=
resizer
.
augment
(
img
)
resized_img
=
resizer
.
augment
(
img
)
scale
=
(
resized_img
.
shape
[
0
]
*
1.0
/
img
.
shape
[
0
]
+
resized_img
.
shape
[
1
]
*
1.0
/
img
.
shape
[
1
])
/
2
scale
=
(
resized_img
.
shape
[
0
]
*
1.0
/
img
.
shape
[
0
]
+
resized_img
.
shape
[
1
]
*
1.0
/
img
.
shape
[
1
])
/
2
fg_probs
,
fg_boxe
s
=
model_func
(
resized_img
)
boxes
,
probs
,
label
s
=
model_func
(
resized_img
)
fg_boxes
=
fg_
boxes
/
scale
boxes
=
boxes
/
scale
fg_boxes
=
clip_boxes
(
fg_
boxes
,
img
.
shape
[:
2
])
boxes
=
clip_boxes
(
boxes
,
img
.
shape
[:
2
])
return
nms_fastrcnn_results
(
fg_boxes
,
fg_prob
s
)
return
group_results_by_class
(
boxes
,
probs
,
label
s
)
def
eval_on_dataflow
(
df
,
detect_func
):
def
eval_on_dataflow
(
df
,
detect_func
):
"""
"""
Args:
Args:
df: a DataFlow which produces (image, image_id)
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns
a dict
detect_func: a callable, takes [image] and returns
[DetectionResult]
Returns:
Returns:
list of dict, to be dumped to COCO json format
list of dict, to be dumped to COCO json format
...
...
examples/FasterRCNN/model.py
View file @
4db43fec
...
@@ -254,13 +254,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
...
@@ -254,13 +254,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
def
sample_fg_bg
(
iou
):
def
sample_fg_bg
(
iou
):
fg_mask
=
tf
.
reduce_max
(
iou
,
axis
=
1
)
>=
config
.
FASTRCNN_FG_THRESH
fg_mask
=
tf
.
reduce_max
(
iou
,
axis
=
1
)
>=
config
.
FASTRCNN_FG_THRESH
fg_inds
=
tf
.
where
(
fg_mask
)[:,
0
]
fg_inds
=
tf
.
reshape
(
tf
.
where
(
fg_mask
),
[
-
1
])
num_fg
=
tf
.
minimum
(
int
(
num_fg
=
tf
.
minimum
(
int
(
config
.
FASTRCNN_BATCH_PER_IM
*
config
.
FASTRCNN_FG_RATIO
),
config
.
FASTRCNN_BATCH_PER_IM
*
config
.
FASTRCNN_FG_RATIO
),
tf
.
size
(
fg_inds
),
name
=
'num_fg'
)
tf
.
size
(
fg_inds
),
name
=
'num_fg'
)
fg_inds
=
tf
.
random_shuffle
(
fg_inds
)[:
num_fg
]
fg_inds
=
tf
.
random_shuffle
(
fg_inds
)[:
num_fg
]
bg_inds
=
tf
.
where
(
tf
.
logical_not
(
fg_mask
))[:,
0
]
bg_inds
=
tf
.
reshape
(
tf
.
where
(
tf
.
logical_not
(
fg_mask
)),
[
-
1
])
num_bg
=
tf
.
minimum
(
num_bg
=
tf
.
minimum
(
config
.
FASTRCNN_BATCH_PER_IM
-
num_fg
,
config
.
FASTRCNN_BATCH_PER_IM
-
num_fg
,
tf
.
size
(
bg_inds
),
name
=
'num_bg'
)
tf
.
size
(
bg_inds
),
name
=
'num_bg'
)
...
@@ -383,27 +383,6 @@ def fastrcnn_head(feature, num_classes):
...
@@ -383,27 +383,6 @@ def fastrcnn_head(feature, num_classes):
return
classification
,
box_regression
return
classification
,
box_regression
@
under_name_scope
()
def
fastrcnn_predict_boxes
(
labels
,
box_logits
):
"""
Args:
labels: n,
box_logits: nx(C-1)x4
Returns:
fg_ind: fg, indices into n
fg_box_logits: fgx4
"""
fg_ind
=
tf
.
reshape
(
tf
.
where
(
labels
>
0
),
[
-
1
])
# nfg,
fg_labels
=
tf
.
gather
(
labels
,
fg_ind
)
# nfg,
ind_2d
=
tf
.
stack
([
fg_ind
,
fg_labels
-
1
],
axis
=
1
)
# nfgx2
# n x c-1 x 4 -> nfgx4
fg_box_logits
=
tf
.
gather_nd
(
box_logits
,
tf
.
stop_gradient
(
ind_2d
))
return
fg_ind
,
fg_box_logits
@
under_name_scope
()
@
under_name_scope
()
def
fastrcnn_losses
(
labels
,
label_logits
,
fg_boxes
,
fg_box_logits
):
def
fastrcnn_losses
(
labels
,
label_logits
,
fg_boxes
,
fg_box_logits
):
"""
"""
...
@@ -442,3 +421,63 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
...
@@ -442,3 +421,63 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
for
k
in
[
label_loss
,
box_loss
,
accuracy
,
fg_accuracy
,
false_negative
]:
for
k
in
[
label_loss
,
box_loss
,
accuracy
,
fg_accuracy
,
false_negative
]:
add_moving_summary
(
k
)
add_moving_summary
(
k
)
return
label_loss
,
box_loss
return
label_loss
,
box_loss
@
under_name_scope
()
def
fastrcnn_predictions
(
boxes
,
probs
):
"""
Generate final results from predictions of all proposals.
Args:
boxes: n#catx4 floatbox in float32
probs: nx#class
"""
assert
boxes
.
shape
[
1
]
==
config
.
NUM_CLASS
-
1
assert
probs
.
shape
[
1
]
==
config
.
NUM_CLASS
N
=
tf
.
shape
(
boxes
)[
0
]
boxes
=
tf
.
transpose
(
boxes
,
[
1
,
0
,
2
])
# #catxnx4
probs
=
tf
.
transpose
(
probs
[:,
1
:],
[
1
,
0
])
# #catxn
def
f
(
X
):
"""
prob: n probabilities
box: nx4 boxes
Returns: n boolean, the selection
"""
prob
,
box
=
X
output_shape
=
tf
.
shape
(
prob
)
# filter by score threshold
ids
=
tf
.
reshape
(
tf
.
where
(
prob
>
config
.
RESULT_SCORE_THRESH
),
[
-
1
])
prob
=
tf
.
gather
(
prob
,
ids
)
box
=
tf
.
gather
(
box
,
ids
)
# NMS within each class
selection
=
tf
.
image
.
non_max_suppression
(
box
,
prob
,
config
.
RESULTS_PER_IM
,
config
.
FASTRCNN_NMS_THRESH
)
selection
=
tf
.
to_int32
(
tf
.
gather
(
ids
,
selection
))
# sort available in TF>1.4.0
# selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection
,
_
=
tf
.
nn
.
top_k
(
-
selection
,
k
=
tf
.
size
(
selection
))
mask
=
tf
.
sparse_to_dense
(
sparse_indices
=-
sorted_selection
,
output_shape
=
output_shape
,
sparse_values
=
True
,
default_value
=
False
)
return
mask
masks
=
tf
.
map_fn
(
f
,
(
probs
,
boxes
),
dtype
=
tf
.
bool
,
parallel_iterations
=
10
)
# #cat x N
selected_indices
=
tf
.
where
(
masks
)
# #selection x 2, each is (cat_id, box_id)
boxes
=
tf
.
boolean_mask
(
boxes
,
masks
)
# #selection x 4
probs
=
tf
.
boolean_mask
(
probs
,
masks
)
labels
=
selected_indices
[:,
0
]
+
1
# filter again by sorting scores
topk_probs
,
topk_indices
=
tf
.
nn
.
top_k
(
probs
,
tf
.
minimum
(
config
.
RESULTS_PER_IM
,
tf
.
size
(
probs
)),
sorted
=
False
)
topk_probs
=
tf
.
identity
(
topk_probs
,
name
=
'probs'
)
topk_boxes
=
tf
.
gather
(
boxes
,
topk_indices
,
name
=
'boxes'
)
topk_labels
=
tf
.
gather
(
labels
,
topk_indices
,
name
=
'labels'
)
return
topk_boxes
,
topk_probs
,
topk_labels
examples/FasterRCNN/train.py
View file @
4db43fec
...
@@ -27,10 +27,10 @@ from coco import COCODetection
...
@@ -27,10 +27,10 @@ from coco import COCODetection
from
basemodel
import
(
from
basemodel
import
(
image_preprocess
,
pretrained_resnet_conv4
,
resnet_conv5
)
image_preprocess
,
pretrained_resnet_conv4
,
resnet_conv5
)
from
model
import
(
from
model
import
(
rpn_head
,
rpn_losses
,
decode_bbox_target
,
encode_bbox_target
,
decode_bbox_target
,
encode_bbox_target
,
rpn_head
,
rpn_losses
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predict
_boxe
s
)
roi_align
,
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predict
ion
s
)
from
data
import
(
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
)
get_all_anchors
)
...
@@ -39,8 +39,7 @@ from viz import (
...
@@ -39,8 +39,7 @@ from viz import (
draw_predictions
,
draw_final_outputs
)
draw_predictions
,
draw_final_outputs
)
from
common
import
clip_boxes
,
CustomResize
,
print_config
from
common
import
clip_boxes
,
CustomResize
,
print_config
from
eval
import
(
from
eval
import
(
eval_on_dataflow
,
detect_one_image
,
print_evaluation_scores
,
eval_on_dataflow
,
detect_one_image
,
print_evaluation_scores
)
nms_fastrcnn_results
)
import
config
import
config
...
@@ -112,7 +111,7 @@ class Model(ModelDesc):
...
@@ -112,7 +111,7 @@ class Model(ModelDesc):
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
if
is_training
:
if
is_training
:
fg_inds_wrt_sample
=
tf
.
where
(
rcnn_labels
>
0
)[:,
0
]
# fg inds w.r.t all samples
fg_inds_wrt_sample
=
tf
.
reshape
(
tf
.
where
(
rcnn_labels
>
0
),
[
-
1
])
# fg inds w.r.t all samples
fg_sampled_boxes
=
tf
.
gather
(
rcnn_sampled_boxes
,
fg_inds_wrt_sample
)
fg_sampled_boxes
=
tf
.
gather
(
rcnn_sampled_boxes
,
fg_inds_wrt_sample
)
matched_gt_boxes
=
tf
.
gather
(
gt_boxes
,
fg_inds_wrt_gt
)
matched_gt_boxes
=
tf
.
gather
(
gt_boxes
,
fg_inds_wrt_gt
)
...
@@ -143,6 +142,8 @@ class Model(ModelDesc):
...
@@ -143,6 +142,8 @@ class Model(ModelDesc):
tf
.
constant
(
config
.
FASTRCNN_BBOX_REG_WEIGHTS
),
anchors
)
tf
.
constant
(
config
.
FASTRCNN_BBOX_REG_WEIGHTS
),
anchors
)
decoded_boxes
=
tf
.
identity
(
decoded_boxes
,
name
=
'fastrcnn_all_boxes'
)
decoded_boxes
=
tf
.
identity
(
decoded_boxes
,
name
=
'fastrcnn_all_boxes'
)
pred_boxes
,
pred_probs
,
pred_labels
=
fastrcnn_predictions
(
decoded_boxes
,
label_probs
)
def
_get_optimizer
(
self
):
def
_get_optimizer
(
self
):
lr
=
tf
.
get_variable
(
'learning_rate'
,
initializer
=
0.003
,
trainable
=
False
)
lr
=
tf
.
get_variable
(
'learning_rate'
,
initializer
=
0.003
,
trainable
=
False
)
tf
.
summary
.
scalar
(
'learning_rate'
,
lr
)
tf
.
summary
.
scalar
(
'learning_rate'
,
lr
)
...
@@ -166,8 +167,9 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
...
@@ -166,8 +167,9 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
'generate_rpn_proposals/boxes'
,
'generate_rpn_proposals/boxes'
,
'generate_rpn_proposals/probs'
,
'generate_rpn_proposals/probs'
,
'fastrcnn_all_probs'
,
'fastrcnn_all_probs'
,
'fastrcnn_fg_probs'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_fg_boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
]))
df
=
get_train_dataflow
()
df
=
get_train_dataflow
()
df
.
reset_state
()
df
.
reset_state
()
...
@@ -179,21 +181,21 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
...
@@ -179,21 +181,21 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
for
idx
,
dp
in
itertools
.
islice
(
enumerate
(
df
.
get_data
()),
nr_visualize
):
for
idx
,
dp
in
itertools
.
islice
(
enumerate
(
df
.
get_data
()),
nr_visualize
):
img
,
_
,
_
,
gt_boxes
,
gt_labels
=
dp
img
,
_
,
_
,
gt_boxes
,
gt_labels
=
dp
rpn_boxes
,
rpn_scores
,
all_probs
,
fg_probs
,
fg_boxes
=
pred
(
img
,
gt_boxes
,
gt_labels
)
rpn_boxes
,
rpn_scores
,
all_probs
,
\
final_boxes
,
final_probs
,
final_labels
=
pred
(
img
,
gt_boxes
,
gt_labels
)
# draw groundtruth boxes
gt_viz
=
draw_annotation
(
img
,
gt_boxes
,
gt_labels
)
gt_viz
=
draw_annotation
(
img
,
gt_boxes
,
gt_labels
)
# draw best proposals for each groundtruth, to show recall
proposal_viz
,
good_proposals_ind
=
draw_proposal_recall
(
img
,
rpn_boxes
,
rpn_scores
,
gt_boxes
)
proposal_viz
,
good_proposals_ind
=
draw_proposal_recall
(
img
,
rpn_boxes
,
rpn_scores
,
gt_boxes
)
# draw the scores for the above proposals
score_viz
=
draw_predictions
(
img
,
rpn_boxes
[
good_proposals_ind
],
all_probs
[
good_proposals_ind
])
score_viz
=
draw_predictions
(
img
,
rpn_boxes
[
good_proposals_ind
],
all_probs
[
good_proposals_ind
])
fg_boxes
=
clip_boxes
(
fg_boxes
,
img
.
shape
[:
2
])
final_viz
=
draw_final_outputs
(
img
,
final_boxes
,
final_probs
,
final_labels
)
fg_viz
=
draw_predictions
(
img
,
fg_boxes
,
fg_probs
)
results
=
nms_fastrcnn_results
(
fg_boxes
,
fg_probs
)
final_viz
=
draw_final_outputs
(
img
,
results
)
viz
=
tpviz
.
stack_patches
([
viz
=
tpviz
.
stack_patches
([
gt_viz
,
proposal_viz
,
score_viz
,
gt_viz
,
proposal_viz
,
fg_viz
,
final_viz
],
2
,
3
)
score_viz
,
final_viz
],
2
,
2
)
if
os
.
environ
.
get
(
'DISPLAY'
,
None
):
if
os
.
environ
.
get
(
'DISPLAY'
,
None
):
tpviz
.
interactive_imshow
(
viz
)
tpviz
.
interactive_imshow
(
viz
)
...
@@ -207,8 +209,9 @@ def offline_evaluate(model_path, output_file):
...
@@ -207,8 +209,9 @@ def offline_evaluate(model_path, output_file):
session_init
=
get_model_loader
(
model_path
),
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
input_names
=
[
'image'
],
output_names
=
[
output_names
=
[
'fastrcnn_all_probs'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_all_boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
]))
df
=
get_eval_dataflow
()
df
=
get_eval_dataflow
()
df
=
PrefetchDataZMQ
(
df
,
1
)
df
=
PrefetchDataZMQ
(
df
,
1
)
...
@@ -224,8 +227,9 @@ def predict(model_path, input_file):
...
@@ -224,8 +227,9 @@ def predict(model_path, input_file):
session_init
=
get_model_loader
(
model_path
),
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
input_names
=
[
'image'
],
output_names
=
[
output_names
=
[
'fastrcnn_all_probs'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_all_boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
]))
img
=
cv2
.
imread
(
input_file
,
cv2
.
IMREAD_COLOR
)
img
=
cv2
.
imread
(
input_file
,
cv2
.
IMREAD_COLOR
)
results
=
detect_one_image
(
img
,
pred
)
results
=
detect_one_image
(
img
,
pred
)
...
@@ -237,7 +241,10 @@ def predict(model_path, input_file):
...
@@ -237,7 +241,10 @@ def predict(model_path, input_file):
class
EvalCallback
(
Callback
):
class
EvalCallback
(
Callback
):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
(
self
.
pred
=
self
.
trainer
.
get_predictor
(
[
'image'
],
[
'fastrcnn_all_probs'
,
'fastrcnn_all_boxes'
])
[
'image'
],
[
'fastrcnn_predictions/boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
])
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
def
_before_train
(
self
):
def
_before_train
(
self
):
...
...
examples/FasterRCNN/viz.py
View file @
4db43fec
...
@@ -63,21 +63,19 @@ def draw_predictions(img, boxes, scores):
...
@@ -63,21 +63,19 @@ def draw_predictions(img, boxes, scores):
return
viz
.
draw_boxes
(
img
,
boxes
,
tags
)
return
viz
.
draw_boxes
(
img
,
boxes
,
tags
)
def
draw_final_outputs
(
img
,
result
s
):
def
draw_final_outputs
(
img
,
final_boxes
,
final_probs
,
final_label
s
):
"""
"""
Args:
Args:
results: [DetectionResult]
results: [DetectionResult]
"""
"""
all_boxes
=
[]
if
final_boxes
.
shape
[
0
]
==
0
:
all_tags
=
[]
for
class_id
,
boxes
,
scores
in
results
:
all_boxes
.
extend
(
boxes
)
all_tags
.
extend
(
[
"{},{:.2f}"
.
format
(
COCOMeta
.
class_names
[
class_id
],
sc
)
for
sc
in
scores
])
all_boxes
=
np
.
asarray
(
all_boxes
)
if
all_boxes
.
shape
[
0
]
==
0
:
return
img
return
img
return
viz
.
draw_boxes
(
img
,
all_boxes
,
all_tags
)
tags
=
[]
for
prob
,
label
in
zip
(
final_probs
,
final_labels
):
tags
.
append
(
"{},{:.2f}"
.
format
(
COCOMeta
.
class_names
[
label
],
prob
))
return
viz
.
draw_boxes
(
img
,
final_boxes
,
tags
)
def
draw_mask
(
im
,
mask
,
alpha
=
0.5
,
color
=
None
):
def
draw_mask
(
im
,
mask
,
alpha
=
0.5
,
color
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment