Commit 7f55d502 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] suppress tf logging; fix nan precision (#494)

parent 39afd64d
......@@ -226,7 +226,7 @@ def get_train_dataflow(add_mask=False):
if not len(boxes):
raise MalformedData("No valid gt_boxes!")
except MalformedData as e:
log_once("Input {} is invalid for training: {}".format(fname, str(e)), 'warn')
log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
return None
ret = [im, fm_labels, fm_boxes, boxes, klass]
......
......@@ -78,16 +78,16 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
with tf.device('/cpu:0'):
for th in [0.5, 0.2, 0.1]:
valid_prediction = tf.cast(valid_label_prob > th, tf.int32)
nr_pos_prediction = tf.count_nonzero(valid_prediction, name='num_pos_prediction')
nr_pos_prediction = tf.reduce_sum(valid_prediction, name='num_pos_prediction')
pos_prediction_corr = tf.count_nonzero(tf.logical_and(
valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)))
summaries.append(tf.truediv(
pos_prediction_corr,
nr_pos, name='recall_th{}'.format(th)))
summaries.append(tf.truediv(
pos_prediction_corr,
nr_pos_prediction, name='precision_th{}'.format(th)))
precision = tf.truediv(pos_prediction_corr, nr_pos_prediction)
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0, precision, name='precision_th{}'.format(th))
summaries.append(precision)
label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
......
......@@ -75,10 +75,12 @@ def get_iou_callable():
"""
Get a pairwise box iou callable.
"""
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['CUDA_VISIBLE_DEVICES'] = '' # we don't want the dataflow process to touch CUDA
with tf.Graph().as_default(), tf.device('/cpu:0'):
A = tf.placeholder(tf.float32, shape=[None, 4])
B = tf.placeholder(tf.float32, shape=[None, 4])
iou = pairwise_iou(A, B)
tf.logging.set_verbosity(tf.logging.FATAL) # TF will warn about GPU not found
sess = tf.Session(config=get_default_sess_config())
tf.logging.set_verbosity(tf.logging.INFO)
return sess.make_callable(iou, [A, B])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment