Commit eb1fa920 authored by Yuxin Wu's avatar Yuxin Wu

fix nan in class_balanced_sigmoid_cross_entropy when only 1 label exists.

parent 8ad7e2b4
...@@ -77,8 +77,8 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss ...@@ -77,8 +77,8 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
pos_weight = beta / (1 - beta) pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight) cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name) cost = tf.reduce_mean(cost * (1 - beta))
return cost return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)
def print_stat(x, message=None): def print_stat(x, message=None):
......
...@@ -26,7 +26,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -26,7 +26,7 @@ class FeedfreeTrainerBase(Trainer):
def _trigger_epoch(self): def _trigger_epoch(self):
# run summary_op every epoch # run summary_op every epoch
# TODO summary_op will take a data! This is not good for TensorInput. # TODO FIXME summary_op will take a data! This is not good for TensorInput.
if self.summary_op is not None: if self.summary_op is not None:
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self.add_summary(summary_str) self.add_summary(summary_str)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment