Commit ced8330e authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] fix type mismatch (#494)

parent 7510f165
......@@ -66,8 +66,8 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
with tf.device('/cpu:0'):
valid_mask = tf.stop_gradient(tf.not_equal(anchor_labels, -1))
pos_mask = tf.stop_gradient(tf.equal(anchor_labels, 1))
nr_valid = tf.stop_gradient(tf.count_nonzero(valid_mask), name='num_valid_anchor')
nr_pos = tf.count_nonzero(pos_mask, name='num_pos_anchor')
nr_valid = tf.stop_gradient(tf.count_nonzero(valid_mask, dtype=tf.int32), name='num_valid_anchor')
nr_pos = tf.count_nonzero(pos_mask, dtype=tf.int32, name='num_pos_anchor')
valid_anchor_labels = tf.boolean_mask(anchor_labels, valid_mask)
valid_label_logits = tf.boolean_mask(label_logits, valid_mask)
......@@ -79,14 +79,16 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
for th in [0.5, 0.2, 0.1]:
valid_prediction = tf.cast(valid_label_prob > th, tf.int32)
nr_pos_prediction = tf.reduce_sum(valid_prediction, name='num_pos_prediction')
pos_prediction_corr = tf.count_nonzero(tf.logical_and(
valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)))
pos_prediction_corr = tf.count_nonzero(
tf.logical_and(
valid_label_prob > th,
tf.equal(valid_prediction, valid_anchor_labels)),
dtype=tf.int32)
summaries.append(tf.truediv(
pos_prediction_corr,
nr_pos, name='recall_th{}'.format(th)))
precision = tf.truediv(pos_prediction_corr, nr_pos_prediction)
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0, precision, name='precision_th{}'.format(th))
precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0.0, precision, name='precision_th{}'.format(th))
summaries.append(precision)
label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment