Commit 1aaadca9 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] speedup inference

parent de6a2fed
......@@ -195,59 +195,21 @@ def fastrcnn_predictions(boxes, scores):
boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4
scores = tf.transpose(scores[:, 1:], [1, 0]) # #catxn
def f(X):
"""
prob: n probabilities
box: nx4 boxes
Returns: n boolean, the selection
"""
prob, box = X
output_shape = tf.shape(prob, out_type=tf.int64)
# filter by score threshold
ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1])
prob = tf.gather(prob, ids)
box = tf.gather(box, ids)
# NMS within each class
max_coord = tf.reduce_max(boxes)
filtered_ids = tf.where(scores > cfg.TEST.RESULT_SCORE_THRESH) # Fx2
filtered_boxes = tf.gather_nd(boxes, filtered_ids) # Fx4
filtered_scores = tf.gather_nd(scores, filtered_ids) # F,
cls_per_box = tf.slice(filtered_ids, [0, 0], [-1, 1])
offsets = tf.cast(cls_per_box, tf.float32) * (max_coord + 1) # F,1
with tf.device('/cpu:0'):
selection = tf.image.non_max_suppression(
box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH)
selection = tf.gather(ids, selection)
if get_tf_version_tuple() >= (1, 13):
sorted_selection = tf.sort(selection, direction='ASCENDING')
mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1),
values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=output_shape)
mask = tf.sparse.to_dense(mask, default_value=False)
else:
# this function is deprecated by TF
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
mask = tf.sparse_to_dense(
sparse_indices=sorted_selection,
output_shape=output_shape,
sparse_values=True,
default_value=False)
return mask
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
buggy_tf = get_tf_version_tuple() in [(1, 11), (1, 12)]
masks = tf.map_fn(f, (scores, boxes), dtype=tf.bool,
parallel_iterations=1 if buggy_tf else 10) # #cat x N
selected_indices = tf.where(masks) # #selection x 2, each is (cat_id, box_id)
scores = tf.boolean_mask(scores, masks)
# filter again by sorting scores
topk_scores, topk_indices = tf.nn.top_k(
scores,
tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(scores)),
sorted=False)
filtered_selection = tf.gather(selected_indices, topk_indices)
cat_ids, box_ids = tf.unstack(filtered_selection, axis=1)
final_scores = tf.identity(topk_scores, name='scores')
final_labels = tf.add(cat_ids, 1, name='labels')
final_ids = tf.stack([cat_ids, box_ids], axis=1, name='all_ids')
final_boxes = tf.gather_nd(boxes, final_ids, name='boxes')
filtered_boxes + offsets,
filtered_scores,
cfg.TEST.RESULTS_PER_IM,
cfg.TEST.FRCNN_NMS_THRESH)
final_scores = tf.gather(filtered_scores, selection, name='scores')
final_labels = tf.add(tf.gather(cls_per_box[:, 0], selection), 1, name='labels')
final_boxes = tf.gather(filtered_boxes, selection, name='boxes')
return final_boxes, final_scores, final_labels
......
......@@ -142,6 +142,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
topk_valid_boxes_y1x1y2x2 = tf.reshape(
tf.reverse(topk_valid_boxes_x1y1x2y2, axis=[2]),
(-1, 4), name='nms_input_boxes')
with tf.device('/cpu:0'):
nms_indices = tf.image.non_max_suppression(
topk_valid_boxes_y1x1y2x2,
topk_valid_scores,
......
......@@ -8,7 +8,6 @@ import os
import time
import tensorflow as tf
from six.moves import map, queue
from tensorflow.python.client import timeline
import psutil
from ..tfutils.common import gpu_available_in_session
......@@ -200,6 +199,7 @@ class GraphProfiler(Callback):
f.write(metadata.SerializeToString())
def _write_tracing(self, metadata):
from tensorflow.python.client import timeline
tl = timeline.Timeline(step_stats=metadata.step_stats)
fname = os.path.join(
self._dir, 'chrome-trace-{}.json'.format(self.global_step))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment