Source code for spear.Implyloss.gen_cross_entropy_utils

# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# implementation of the generalized cross entropy loss
# eq. (6) of the paper https://arxiv.org/pdf/1805.07836.pdf
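# For reference (transcribed from the paper, not in the original comments):
#   L_q(f(x), e_j) = (1 - f_j(x)^q) / q
# where f_j(x) is the softmax probability of the true class j and 0 < q <= 1.
# As q -> 0 this tends to the usual cross entropy -log f_j(x); at q = 1 it
# equals the noise-robust mean absolute error 1 - f_j(x).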


# loss for a multiclass prediction problem
# (logits defining a softmax distribution and the corresponding one-hot labels)
def generalized_cross_entropy(logits, one_hot_labels, q=0.6):
    '''
    Func Desc:
    Computes the generalized cross entropy loss

    Input:
    logits ([batch_size, num_classes]) - unnormalized class scores (pre-softmax)
    one_hot_labels ([batch_size, num_classes]) - one-hot encoded true labels
    q (default = 0.6) - the generalization exponent

    Output:
    loss - scalar, the mean loss over the batch
    '''
    if q == 0.0:
        # in the limit q -> 0 this is the usual cross entropy
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels, logits=logits)
        loss = tf.reduce_mean(loss)
    else:
        exp_logits = tf.exp(logits)
        # keepdims=True keeps normalizer_q at shape [batch_size, 1] so that it
        # broadcasts against exp_logits_q ([batch_size, num_classes])
        normalizer = tf.reduce_sum(exp_logits, axis=-1, keepdims=True)
        normalizer_q = tf.pow(normalizer, q)
        exp_logits_q = tf.exp(logits * q)
        f_j_q = exp_logits_q / normalizer_q  # softmax probabilities raised to the power q
        loss = (1 - f_j_q) / q
        loss = tf.reduce_sum(loss * one_hot_labels, axis=-1)  # pick out the true class
        loss = tf.reduce_mean(loss)
    return loss
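
# Illustrative usage (a minimal sketch, not part of the original module):
# evaluates the loss on a toy batch under TF1-style graph execution. The
# logits, labels, and shapes below are made-up values for this example.
def _demo_generalized_cross_entropy():
    logits = tf.constant([[2.0, 0.5, -1.0],
                          [0.1, 1.5, 0.3]])  # [batch_size=2, num_classes=3]
    one_hot_labels = tf.constant([[1.0, 0.0, 0.0],
                                  [0.0, 1.0, 0.0]])
    loss_q = generalized_cross_entropy(logits, one_hot_labels, q=0.6)
    loss_ce = generalized_cross_entropy(logits, one_hot_labels, q=0.0)  # plain cross entropy
    with tf.Session() as sess:
        print(sess.run([loss_q, loss_ce]))
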
# loss for a particular class
# p = probability of that class
def generalized_cross_entropy_bernoulli(p, q=0.2):
    '''
    Func Desc:
    Computes the Bernoulli generalized cross entropy

    Input:
    p - the probability of the class
    q (default = 0.2) - the generalization exponent

    Output:
    loss - elementwise loss, same shape as p
    '''
    return (1 - tf.pow(p, q)) / q
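
# Illustrative usage (a minimal sketch, not part of the original module):
# the Bernoulli variant consumes probabilities directly, e.g. sigmoid
# outputs; the inputs below are made-up values for this example.
def _demo_generalized_cross_entropy_bernoulli():
    p = tf.sigmoid(tf.constant([1.2, -0.4, 0.0]))  # assumed per-example probabilities
    loss = tf.reduce_mean(generalized_cross_entropy_bernoulli(p, q=0.2))
    with tf.Session() as sess:
        print(sess.run(loss))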