Commit f2ca6b1a authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] parameterize `get_tf_nms` and make it in a standalone graph.

parent a042f821
......@@ -27,17 +27,18 @@ DetectionResult = namedtuple(
@memoized
def get_tf_nms():
def get_tf_nms(num_output, thresh):
"""
Get a NMS callable.
"""
boxes = tf.placeholder(tf.float32, shape=[None, 4])
scores = tf.placeholder(tf.float32, shape=[None])
indices = tf.image.non_max_suppression(
boxes, scores,
config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
sess = tf.Session(config=get_default_sess_config())
return sess.make_callable(indices, [boxes, scores])
# create a new graph for it
with tf.Graph().as_default(), tf.device('/cpu:0'):
boxes = tf.placeholder(tf.float32, shape=[None, 4])
scores = tf.placeholder(tf.float32, shape=[None])
indices = tf.image.non_max_suppression(
boxes, scores, num_output, thresh)
sess = tf.Session(config=get_default_sess_config())
return sess.make_callable(indices, [boxes, scores])
def nms_fastrcnn_results(boxes, probs):
......@@ -53,7 +54,7 @@ def nms_fastrcnn_results(boxes, probs):
boxes = boxes.copy()
boxes_per_class = {}
nms_func = get_tf_nms()
nms_func = get_tf_nms(config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
ret = []
for klass in range(1, C):
ids = np.where(probs[:, klass] > config.RESULT_SCORE_THRESH)[0]
......
......@@ -232,7 +232,6 @@ class EvalCallback(Callback):
def _setup_graph(self):
self.pred = self.trainer.get_predictor(['image'], ['fastrcnn_fg_probs', 'fastrcnn_fg_boxes'])
self.df = PrefetchDataZMQ(get_eval_dataflow(), 1)
get_tf_nms() # just to make sure the nms part of graph is created
def _before_train(self):
EVAL_TIMES = 5 # eval 5 times during training
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment