Commit 1aaadca9 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] speedup inference

parent de6a2fed
...@@ -195,59 +195,21 @@ def fastrcnn_predictions(boxes, scores): ...@@ -195,59 +195,21 @@ def fastrcnn_predictions(boxes, scores):
boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4 boxes = tf.transpose(boxes, [1, 0, 2])[1:, :, :] # #catxnx4
scores = tf.transpose(scores[:, 1:], [1, 0]) # #catxn scores = tf.transpose(scores[:, 1:], [1, 0]) # #catxn
def f(X): max_coord = tf.reduce_max(boxes)
""" filtered_ids = tf.where(scores > cfg.TEST.RESULT_SCORE_THRESH) # Fx2
prob: n probabilities filtered_boxes = tf.gather_nd(boxes, filtered_ids) # Fx4
box: nx4 boxes filtered_scores = tf.gather_nd(scores, filtered_ids) # F,
cls_per_box = tf.slice(filtered_ids, [0, 0], [-1, 1])
Returns: n boolean, the selection offsets = tf.cast(cls_per_box, tf.float32) * (max_coord + 1) # F,1
""" with tf.device('/cpu:0'):
prob, box = X
output_shape = tf.shape(prob, out_type=tf.int64)
# filter by score threshold
ids = tf.reshape(tf.where(prob > cfg.TEST.RESULT_SCORE_THRESH), [-1])
prob = tf.gather(prob, ids)
box = tf.gather(box, ids)
# NMS within each class
selection = tf.image.non_max_suppression( selection = tf.image.non_max_suppression(
box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH) filtered_boxes + offsets,
selection = tf.gather(ids, selection) filtered_scores,
cfg.TEST.RESULTS_PER_IM,
if get_tf_version_tuple() >= (1, 13): cfg.TEST.FRCNN_NMS_THRESH)
sorted_selection = tf.sort(selection, direction='ASCENDING') final_scores = tf.gather(filtered_scores, selection, name='scores')
mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1), final_labels = tf.add(tf.gather(cls_per_box[:, 0], selection), 1, name='labels')
values=tf.ones_like(sorted_selection, dtype=tf.bool), final_boxes = tf.gather(filtered_boxes, selection, name='boxes')
dense_shape=output_shape)
mask = tf.sparse.to_dense(mask, default_value=False)
else:
# this function is deprecated by TF
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
mask = tf.sparse_to_dense(
sparse_indices=sorted_selection,
output_shape=output_shape,
sparse_values=True,
default_value=False)
return mask
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
buggy_tf = get_tf_version_tuple() in [(1, 11), (1, 12)]
masks = tf.map_fn(f, (scores, boxes), dtype=tf.bool,
parallel_iterations=1 if buggy_tf else 10) # #cat x N
selected_indices = tf.where(masks) # #selection x 2, each is (cat_id, box_id)
scores = tf.boolean_mask(scores, masks)
# filter again by sorting scores
topk_scores, topk_indices = tf.nn.top_k(
scores,
tf.minimum(cfg.TEST.RESULTS_PER_IM, tf.size(scores)),
sorted=False)
filtered_selection = tf.gather(selected_indices, topk_indices)
cat_ids, box_ids = tf.unstack(filtered_selection, axis=1)
final_scores = tf.identity(topk_scores, name='scores')
final_labels = tf.add(cat_ids, 1, name='labels')
final_ids = tf.stack([cat_ids, box_ids], axis=1, name='all_ids')
final_boxes = tf.gather_nd(boxes, final_ids, name='boxes')
return final_boxes, final_scores, final_labels return final_boxes, final_scores, final_labels
......
...@@ -142,6 +142,7 @@ def generate_rpn_proposals(boxes, scores, img_shape, ...@@ -142,6 +142,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
topk_valid_boxes_y1x1y2x2 = tf.reshape( topk_valid_boxes_y1x1y2x2 = tf.reshape(
tf.reverse(topk_valid_boxes_x1y1x2y2, axis=[2]), tf.reverse(topk_valid_boxes_x1y1x2y2, axis=[2]),
(-1, 4), name='nms_input_boxes') (-1, 4), name='nms_input_boxes')
with tf.device('/cpu:0'):
nms_indices = tf.image.non_max_suppression( nms_indices = tf.image.non_max_suppression(
topk_valid_boxes_y1x1y2x2, topk_valid_boxes_y1x1y2x2,
topk_valid_scores, topk_valid_scores,
......
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
import time import time
import tensorflow as tf import tensorflow as tf
from six.moves import map, queue from six.moves import map, queue
from tensorflow.python.client import timeline
import psutil import psutil
from ..tfutils.common import gpu_available_in_session from ..tfutils.common import gpu_available_in_session
...@@ -200,6 +199,7 @@ class GraphProfiler(Callback): ...@@ -200,6 +199,7 @@ class GraphProfiler(Callback):
f.write(metadata.SerializeToString()) f.write(metadata.SerializeToString())
def _write_tracing(self, metadata): def _write_tracing(self, metadata):
from tensorflow.python.client import timeline
tl = timeline.Timeline(step_stats=metadata.step_stats) tl = timeline.Timeline(step_stats=metadata.step_stats)
fname = os.path.join( fname = os.path.join(
self._dir, 'chrome-trace-{}.json'.format(self.global_step)) self._dir, 'chrome-trace-{}.json'.format(self.global_step))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment