Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4db43fec
Commit
4db43fec
authored
Nov 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] predict boxes inside the graph, and improve performance
parent
3a4fdf64
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
121 additions
and
117 deletions
+121
-117
examples/FasterRCNN/README.md
examples/FasterRCNN/README.md
+4
-4
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+20
-60
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+62
-23
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+27
-20
examples/FasterRCNN/viz.py
examples/FasterRCNN/viz.py
+8
-10
No files found.
examples/FasterRCNN/README.md
View file @
4db43fec
...
...
@@ -2,7 +2,7 @@
This example aims to provide a minimal (<1000 lines) multi-GPU implementation of ResNet50-Faster-RCNN on COCO.
## Dependencies
+
TensorFlow >
= 1.4.0rc0
+
TensorFlow >
1.4.0 (use tf-nightly-gpu for now)
+
Install
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
Pre-trained
[
ResNet50 model
](
https://goo.gl/6XjK9V
)
from tensorpack model zoo.
+
COCO data. It assumes the following directory structure:
...
...
@@ -46,8 +46,8 @@ To evaluate the performance (pretrained models can be downloaded in [model zoo](
Mean Average Precision @IoU=0.50:0.95:
+
trainval35k/minival, FASTRCNN_BATCH=256: 3
3.7
. Takes 49h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=64: 32.
2
. Takes 31h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=256: 3
4.2
. Takes 49h on 8 TitanX.
+
trainval35k/minival, FASTRCNN_BATCH=64: 32.
7
. Takes 31h on 8 TitanX.
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer.
...
...
examples/FasterRCNN/eval.py
View file @
4db43fec
...
...
@@ -5,13 +5,13 @@
import
numpy
as
np
import
tqdm
import
cv2
import
six
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
,
defaultdict
import
tensorflow
as
tf
from
tensorpack.dataflow
import
MapDataComponent
,
TestDataSpeed
from
tensorpack.tfutils
import
get_default_sess_config
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.utils.utils
import
get_tqdm_kwargs
from
pycocotools.coco
import
COCO
...
...
@@ -26,59 +26,6 @@ DetectionResult = namedtuple(
[
'class_id'
,
'boxes'
,
'scores'
])
@memoized
def get_tf_nms(num_output, thresh):
    """
    Build a dedicated graph that runs non-maximum suppression and
    return a callable for it.

    Args:
        num_output: passed to tf.image.non_max_suppression as its third argument.
        thresh: passed to tf.image.non_max_suppression as its fourth argument.

    Returns:
        a session callable taking (boxes, scores) and producing the selected indices.
    """
    nms_graph = tf.Graph()
    # Everything, including the session, is created while the private graph
    # is the default one, so the callable is bound to that graph.
    with nms_graph.as_default(), tf.device('/cpu:0'):
        box_input = tf.placeholder(tf.float32, shape=[None, 4])
        score_input = tf.placeholder(tf.float32, shape=[None])
        kept = tf.image.non_max_suppression(
            box_input, score_input, num_output, thresh)
        session = tf.Session(config=get_default_sess_config())
        nms_callable = session.make_callable(kept, [box_input, score_input])
    return nms_callable
def nms_fastrcnn_results(boxes, probs):
    """
    Filter detections by score, run NMS separately for every class, and
    finally cap the total number of detections at config.RESULTS_PER_IM.

    Args:
        boxes: nx#catx4 floatbox in float32
        probs: nx#class

    Returns:
        [DetectionResult]
    """
    num_class = probs.shape[1]
    boxes = boxes.copy()
    run_nms = get_tf_nms(config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)

    per_class = []
    # Column 0 of probs is skipped; class k reads slot k - 1 of boxes.
    for klass in range(1, num_class):
        candidates = np.where(probs[:, klass] > config.RESULT_SCORE_THRESH)[0]
        if candidates.size == 0:
            continue
        cls_scores = probs[candidates, klass].flatten()
        cls_boxes = boxes[candidates, klass - 1, :]
        kept = run_nms(cls_boxes, cls_scores)
        per_class.append(
            DetectionResult(klass, cls_boxes[kept, :].copy(), cls_scores[kept]))

    if not per_class:
        return per_class

    pooled_scores = np.hstack([det.scores for det in per_class])
    if len(pooled_scores) <= config.RESULTS_PER_IM:
        return per_class

    # Score of the RESULTS_PER_IM-th best detection over all classes.
    cutoff = np.sort(pooled_scores)[-config.RESULTS_PER_IM]
    trimmed = []
    for klass, cls_boxes, cls_scores in per_class:
        survivors = np.where(cls_scores >= cutoff)[0]
        if len(survivors):
            trimmed.append(DetectionResult(
                klass, cls_boxes[survivors, :], cls_scores[survivors]))
    return trimmed
def
detect_one_image
(
img
,
model_func
):
"""
Run detection on one image, using the TF callable.
...
...
@@ -91,20 +38,33 @@ def detect_one_image(img, model_func):
Returns:
[DetectionResult]
"""
def
group_results_by_class
(
boxes
,
probs
,
labels
):
dic
=
defaultdict
(
list
)
for
box
,
prob
,
lab
in
zip
(
boxes
,
probs
,
labels
):
dic
[
lab
]
.
append
((
box
,
prob
))
def
mapf
(
lab
,
values
):
boxes
=
np
.
asarray
([
k
[
0
]
for
k
in
values
])
probs
=
np
.
asarray
([
k
[
1
]
for
k
in
values
])
return
DetectionResult
(
lab
,
boxes
,
probs
)
return
[
mapf
(
k
,
v
)
for
k
,
v
in
six
.
iteritems
(
dic
)]
resizer
=
CustomResize
(
config
.
SHORT_EDGE_SIZE
,
config
.
MAX_SIZE
)
resized_img
=
resizer
.
augment
(
img
)
scale
=
(
resized_img
.
shape
[
0
]
*
1.0
/
img
.
shape
[
0
]
+
resized_img
.
shape
[
1
]
*
1.0
/
img
.
shape
[
1
])
/
2
fg_probs
,
fg_boxe
s
=
model_func
(
resized_img
)
fg_boxes
=
fg_
boxes
/
scale
fg_boxes
=
clip_boxes
(
fg_
boxes
,
img
.
shape
[:
2
])
return
nms_fastrcnn_results
(
fg_boxes
,
fg_prob
s
)
boxes
,
probs
,
label
s
=
model_func
(
resized_img
)
boxes
=
boxes
/
scale
boxes
=
clip_boxes
(
boxes
,
img
.
shape
[:
2
])
return
group_results_by_class
(
boxes
,
probs
,
label
s
)
def
eval_on_dataflow
(
df
,
detect_func
):
"""
Args:
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns
a dict
detect_func: a callable, takes [image] and returns
[DetectionResult]
Returns:
list of dict, to be dumped to COCO json format
...
...
examples/FasterRCNN/model.py
View file @
4db43fec
...
...
@@ -254,13 +254,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
def
sample_fg_bg
(
iou
):
fg_mask
=
tf
.
reduce_max
(
iou
,
axis
=
1
)
>=
config
.
FASTRCNN_FG_THRESH
fg_inds
=
tf
.
where
(
fg_mask
)[:,
0
]
fg_inds
=
tf
.
reshape
(
tf
.
where
(
fg_mask
),
[
-
1
])
num_fg
=
tf
.
minimum
(
int
(
config
.
FASTRCNN_BATCH_PER_IM
*
config
.
FASTRCNN_FG_RATIO
),
tf
.
size
(
fg_inds
),
name
=
'num_fg'
)
fg_inds
=
tf
.
random_shuffle
(
fg_inds
)[:
num_fg
]
bg_inds
=
tf
.
where
(
tf
.
logical_not
(
fg_mask
))[:,
0
]
bg_inds
=
tf
.
reshape
(
tf
.
where
(
tf
.
logical_not
(
fg_mask
)),
[
-
1
])
num_bg
=
tf
.
minimum
(
config
.
FASTRCNN_BATCH_PER_IM
-
num_fg
,
tf
.
size
(
bg_inds
),
name
=
'num_bg'
)
...
...
@@ -383,27 +383,6 @@ def fastrcnn_head(feature, num_classes):
return
classification
,
box_regression
@under_name_scope()
def fastrcnn_predict_boxes(labels, box_logits):
    """
    Pick, for every foreground sample, the box logits of its own class.

    Args:
        labels: n,
        box_logits: nx(C-1)x4

    Returns:
        fg_ind: fg, indices into n
        fg_box_logits: fgx4
    """
    positive_rows = tf.where(labels > 0)
    fg_ind = tf.reshape(positive_rows, [-1])    # nfg,
    fg_labels = tf.gather(labels, fg_ind)       # nfg,
    # Pair each foreground row with column (label - 1) of box_logits.
    ind_2d = tf.stack([fg_ind, fg_labels - 1], axis=1)  # nfgx2
    lookup = tf.stop_gradient(ind_2d)
    # n x c-1 x 4 -> nfgx4
    fg_box_logits = tf.gather_nd(box_logits, lookup)
    return fg_ind, fg_box_logits
@
under_name_scope
()
def
fastrcnn_losses
(
labels
,
label_logits
,
fg_boxes
,
fg_box_logits
):
"""
...
...
@@ -442,3 +421,63 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
for
k
in
[
label_loss
,
box_loss
,
accuracy
,
fg_accuracy
,
false_negative
]:
add_moving_summary
(
k
)
return
label_loss
,
box_loss
@under_name_scope()
def fastrcnn_predictions(boxes, probs):
    """
    Generate final results from predictions of all proposals.

    Args:
        boxes: nx#catx4 floatbox in float32
        probs: nx#class

    Returns:
        boxes, probs, labels of the selected detections. At most
        config.RESULTS_PER_IM detections are returned. The three outputs are
        named 'boxes', 'probs' and 'labels' inside this function's name scope.
    """
    assert boxes.shape[1] == config.NUM_CLASS - 1
    assert probs.shape[1] == config.NUM_CLASS
    boxes = tf.transpose(boxes, [1, 0, 2])  # #catxnx4
    # Drop column 0 of probs so that probs and boxes share the same #cat axis.
    probs = tf.transpose(probs[:, 1:], [1, 0])  # #catxn

    def f(X):
        """
        prob: n probabilities
        box: nx4 boxes

        Returns: n boolean, the selection
        """
        prob, box = X
        output_shape = tf.shape(prob)
        # filter by score threshold
        ids = tf.reshape(tf.where(prob > config.RESULT_SCORE_THRESH), [-1])
        prob = tf.gather(prob, ids)
        box = tf.gather(box, ids)
        # NMS within each class
        selection = tf.image.non_max_suppression(
            box, prob, config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
        # Map NMS output back to indices into the unfiltered n proposals.
        selection = tf.to_int32(tf.gather(ids, selection))
        # The indices are put in ascending order before building the mask.
        # tf.contrib.framework.sort is only available in TF>1.4.0, so the
        # sort is emulated with top_k on the negated indices.
        sorted_selection, _ = tf.nn.top_k(-selection, k=tf.size(selection))
        mask = tf.sparse_to_dense(
            sparse_indices=-sorted_selection,
            output_shape=output_shape,
            sparse_values=True,
            default_value=False)
        return mask

    masks = tf.map_fn(f, (probs, boxes), dtype=tf.bool,
                      parallel_iterations=10)     # #cat x N
    selected_indices = tf.where(masks)  # #selection x 2, each is (cat_id, box_id)
    boxes = tf.boolean_mask(boxes, masks)   # #selection x 4
    probs = tf.boolean_mask(probs, masks)
    # cat_id is 0-based over the #cat axis; add 1 to get the class label.
    labels = selected_indices[:, 0] + 1

    # filter again by sorting scores
    topk_probs, topk_indices = tf.nn.top_k(
        probs,
        tf.minimum(config.RESULTS_PER_IM, tf.size(probs)),
        sorted=False)
    topk_probs = tf.identity(topk_probs, name='probs')
    topk_boxes = tf.gather(boxes, topk_indices, name='boxes')
    topk_labels = tf.gather(labels, topk_indices, name='labels')
    return topk_boxes, topk_probs, topk_labels
examples/FasterRCNN/train.py
View file @
4db43fec
...
...
@@ -27,10 +27,10 @@ from coco import COCODetection
from
basemodel
import
(
image_preprocess
,
pretrained_resnet_conv4
,
resnet_conv5
)
from
model
import
(
rpn_head
,
rpn_losses
,
decode_bbox_target
,
encode_bbox_target
,
rpn_head
,
rpn_losses
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predict
_boxe
s
)
roi_align
,
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predict
ion
s
)
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
)
...
...
@@ -39,8 +39,7 @@ from viz import (
draw_predictions
,
draw_final_outputs
)
from
common
import
clip_boxes
,
CustomResize
,
print_config
from
eval
import
(
eval_on_dataflow
,
detect_one_image
,
print_evaluation_scores
,
nms_fastrcnn_results
)
eval_on_dataflow
,
detect_one_image
,
print_evaluation_scores
)
import
config
...
...
@@ -112,7 +111,7 @@ class Model(ModelDesc):
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
if
is_training
:
fg_inds_wrt_sample
=
tf
.
where
(
rcnn_labels
>
0
)[:,
0
]
# fg inds w.r.t all samples
fg_inds_wrt_sample
=
tf
.
reshape
(
tf
.
where
(
rcnn_labels
>
0
),
[
-
1
])
# fg inds w.r.t all samples
fg_sampled_boxes
=
tf
.
gather
(
rcnn_sampled_boxes
,
fg_inds_wrt_sample
)
matched_gt_boxes
=
tf
.
gather
(
gt_boxes
,
fg_inds_wrt_gt
)
...
...
@@ -143,6 +142,8 @@ class Model(ModelDesc):
tf
.
constant
(
config
.
FASTRCNN_BBOX_REG_WEIGHTS
),
anchors
)
decoded_boxes
=
tf
.
identity
(
decoded_boxes
,
name
=
'fastrcnn_all_boxes'
)
pred_boxes
,
pred_probs
,
pred_labels
=
fastrcnn_predictions
(
decoded_boxes
,
label_probs
)
def
_get_optimizer
(
self
):
lr
=
tf
.
get_variable
(
'learning_rate'
,
initializer
=
0.003
,
trainable
=
False
)
tf
.
summary
.
scalar
(
'learning_rate'
,
lr
)
...
...
@@ -166,8 +167,9 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
'generate_rpn_proposals/boxes'
,
'generate_rpn_proposals/probs'
,
'fastrcnn_all_probs'
,
'fastrcnn_fg_probs'
,
'fastrcnn_fg_boxes'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
df
=
get_train_dataflow
()
df
.
reset_state
()
...
...
@@ -179,21 +181,21 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
for
idx
,
dp
in
itertools
.
islice
(
enumerate
(
df
.
get_data
()),
nr_visualize
):
img
,
_
,
_
,
gt_boxes
,
gt_labels
=
dp
rpn_boxes
,
rpn_scores
,
all_probs
,
fg_probs
,
fg_boxes
=
pred
(
img
,
gt_boxes
,
gt_labels
)
rpn_boxes
,
rpn_scores
,
all_probs
,
\
final_boxes
,
final_probs
,
final_labels
=
pred
(
img
,
gt_boxes
,
gt_labels
)
# draw groundtruth boxes
gt_viz
=
draw_annotation
(
img
,
gt_boxes
,
gt_labels
)
# draw best proposals for each groundtruth, to show recall
proposal_viz
,
good_proposals_ind
=
draw_proposal_recall
(
img
,
rpn_boxes
,
rpn_scores
,
gt_boxes
)
# draw the scores for the above proposals
score_viz
=
draw_predictions
(
img
,
rpn_boxes
[
good_proposals_ind
],
all_probs
[
good_proposals_ind
])
fg_boxes
=
clip_boxes
(
fg_boxes
,
img
.
shape
[:
2
])
fg_viz
=
draw_predictions
(
img
,
fg_boxes
,
fg_probs
)
results
=
nms_fastrcnn_results
(
fg_boxes
,
fg_probs
)
final_viz
=
draw_final_outputs
(
img
,
results
)
final_viz
=
draw_final_outputs
(
img
,
final_boxes
,
final_probs
,
final_labels
)
viz
=
tpviz
.
stack_patches
([
gt_viz
,
proposal_viz
,
score_viz
,
fg_viz
,
final_viz
],
2
,
3
)
gt_viz
,
proposal_viz
,
score_viz
,
final_viz
],
2
,
2
)
if
os
.
environ
.
get
(
'DISPLAY'
,
None
):
tpviz
.
interactive_imshow
(
viz
)
...
...
@@ -207,8 +209,9 @@ def offline_evaluate(model_path, output_file):
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
output_names
=
[
'fastrcnn_all_probs'
,
'fastrcnn_all_boxes'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
df
=
get_eval_dataflow
()
df
=
PrefetchDataZMQ
(
df
,
1
)
...
...
@@ -224,8 +227,9 @@ def predict(model_path, input_file):
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
output_names
=
[
'fastrcnn_all_probs'
,
'fastrcnn_all_boxes'
,
'fastrcnn_predictions/boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
,
]))
img
=
cv2
.
imread
(
input_file
,
cv2
.
IMREAD_COLOR
)
results
=
detect_one_image
(
img
,
pred
)
...
...
@@ -237,7 +241,10 @@ def predict(model_path, input_file):
class
EvalCallback
(
Callback
):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
(
[
'image'
],
[
'fastrcnn_all_probs'
,
'fastrcnn_all_boxes'
])
[
'image'
],
[
'fastrcnn_predictions/boxes'
,
'fastrcnn_predictions/probs'
,
'fastrcnn_predictions/labels'
])
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
def
_before_train
(
self
):
...
...
examples/FasterRCNN/viz.py
View file @
4db43fec
...
...
@@ -63,21 +63,19 @@ def draw_predictions(img, boxes, scores):
return
viz
.
draw_boxes
(
img
,
boxes
,
tags
)
def draw_final_outputs(img, final_boxes, final_probs, final_labels):
    """
    Draw the final detections, tagged with class name and score, on an image.

    Args:
        img: the image to draw on
        final_boxes: the detected boxes, one row per detection
        final_probs: one score per detection
        final_labels: one class id per detection, used to index COCOMeta.class_names

    Returns:
        img itself when there are no boxes, otherwise the result of viz.draw_boxes.
    """
    if final_boxes.shape[0] == 0:
        # Nothing to draw: hand back the input untouched.
        return img
    tags = [
        "{},{:.2f}".format(COCOMeta.class_names[label], prob)
        for prob, label in zip(final_probs, final_labels)
    ]
    return viz.draw_boxes(img, final_boxes, tags)
def
draw_mask
(
im
,
mask
,
alpha
=
0.5
,
color
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment