Commit eb1fa920 authored by Yuxin Wu's avatar Yuxin Wu

fix nan in class_balanced_sigmoid_cross_entropy when only 1 label exists.

parent 8ad7e2b4
......@@ -77,8 +77,8 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name)
return cost
cost = tf.reduce_mean(cost * (1 - beta))
return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)
def print_stat(x, message=None):
......
......@@ -26,7 +26,7 @@ class FeedfreeTrainerBase(Trainer):
def _trigger_epoch(self):
# run summary_op every epoch
# TODO summary_op will take a data! This is not good for TensorInput.
# TODO FIXME summary_op will take a data! This is not good for TensorInput.
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self.add_summary(summary_str)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment