Commit f2ca6b1a authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] parameterize `get_tf_nms` and make it in a standalone graph.

parent a042f821
...@@ -27,17 +27,18 @@ DetectionResult = namedtuple( ...@@ -27,17 +27,18 @@ DetectionResult = namedtuple(
@memoized @memoized
def get_tf_nms(): def get_tf_nms(num_output, thresh):
""" """
Get a NMS callable. Get a NMS callable.
""" """
boxes = tf.placeholder(tf.float32, shape=[None, 4]) # create a new graph for it
scores = tf.placeholder(tf.float32, shape=[None]) with tf.Graph().as_default(), tf.device('/cpu:0'):
indices = tf.image.non_max_suppression( boxes = tf.placeholder(tf.float32, shape=[None, 4])
boxes, scores, scores = tf.placeholder(tf.float32, shape=[None])
config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH) indices = tf.image.non_max_suppression(
sess = tf.Session(config=get_default_sess_config()) boxes, scores, num_output, thresh)
return sess.make_callable(indices, [boxes, scores]) sess = tf.Session(config=get_default_sess_config())
return sess.make_callable(indices, [boxes, scores])
def nms_fastrcnn_results(boxes, probs): def nms_fastrcnn_results(boxes, probs):
...@@ -53,7 +54,7 @@ def nms_fastrcnn_results(boxes, probs): ...@@ -53,7 +54,7 @@ def nms_fastrcnn_results(boxes, probs):
boxes = boxes.copy() boxes = boxes.copy()
boxes_per_class = {} boxes_per_class = {}
nms_func = get_tf_nms() nms_func = get_tf_nms(config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
ret = [] ret = []
for klass in range(1, C): for klass in range(1, C):
ids = np.where(probs[:, klass] > config.RESULT_SCORE_THRESH)[0] ids = np.where(probs[:, klass] > config.RESULT_SCORE_THRESH)[0]
......
...@@ -232,7 +232,6 @@ class EvalCallback(Callback): ...@@ -232,7 +232,6 @@ class EvalCallback(Callback):
def _setup_graph(self): def _setup_graph(self):
self.pred = self.trainer.get_predictor(['image'], ['fastrcnn_fg_probs', 'fastrcnn_fg_boxes']) self.pred = self.trainer.get_predictor(['image'], ['fastrcnn_fg_probs', 'fastrcnn_fg_boxes'])
self.df = PrefetchDataZMQ(get_eval_dataflow(), 1) self.df = PrefetchDataZMQ(get_eval_dataflow(), 1)
get_tf_nms() # just to make sure the nms part of graph is created
def _before_train(self): def _before_train(self):
EVAL_TIMES = 5 # eval 5 times during training EVAL_TIMES = 5 # eval 5 times during training
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment