import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
def exp_term_for_constraints(rule_classes, num_classes, C):
    '''
    Func Desc:
    Compute the exponential constraint term exp(C * (1[y == class(r_j)] - 1))
    for every (class, rule) pair.

    Input:
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int) - number of target classes
        C (float) - constraint weight (lambda in equation 10)

    Output:
        [num_classes, num_rules] tensor whose (y, j) entry is 1.0 when class y
        equals rule j's class and exp(-C) otherwise
    '''
    # [1, num_rules] row of rule class ids
    rule_classes_tensor = tf.to_float(tf.convert_to_tensor(rule_classes))
    rule_classes_tensor = tf.expand_dims(rule_classes_tensor, 0)
    # [num_classes, 1] column of class ids 0..num_classes-1
    class_types_tensor = tf.to_float(
        tf.convert_to_tensor(np.arange(num_classes).reshape(num_classes, 1)))
    # broadcast-compare -> [num_classes, num_rules]:
    # 0.0 where class matches the rule's class, -1.0 elsewhere
    class_rule_constraint = tf.to_float(
        tf.equal(class_types_tensor, rule_classes_tensor)) - 1.0
    # exp(0) = 1 on match, exp(-C) on mismatch
    return tf.exp(C * class_rule_constraint)
def pr_product_term(weights, rule_classes, num_classes, C):
    '''
    Func Desc:
    Compute, for each instance and each candidate label y, the product over
    rules of the expected constraint factor E_{r_j}[exp(C * r_j * (1[y==class_j]-1))]
    = w_j * exp_term[y, j] + (1 - w_j).

    Input:
        weights ([batch_size, num_rules]) - the w_network weights (P(r_j = 1))
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - constraint weight

    Output:
        [batch_size, num_classes] product term
    '''
    # [num_classes, num_rules] -> [num_classes, 1, num_rules]
    class_rule_constraint = exp_term_for_constraints(rule_classes, num_classes, C)
    class_rule_constraint = tf.expand_dims(class_rule_constraint, axis=1)
    # [1, batch_size, num_rules]
    weights = tf.expand_dims(weights, axis=0)
    # [num_classes, batch_size, num_rules]: each factor marginalizes rule j
    # over r_j in {0, 1}: r_j=1 contributes w_j * exp_term, r_j=0 contributes (1-w_j)
    per_rule_factor = class_rule_constraint * weights + (1 - weights)
    # product over rules -> [num_classes, batch_size]
    product_term = tf.reduce_prod(per_rule_factor, axis=-1)
    # transpose to [batch_size, num_classes]
    return tf.transpose(product_term)
def get_q_y_from_p(f_probs, weights, rule_classes, num_classes, C):
    '''
    Func Desc:
    Compute the label posterior q(y) from the p (f_network) distribution by
    reweighting f_probs with the per-class rule-constraint product term and
    renormalizing over classes.

    Input:
        f_probs ([batch_size, num_classes]) - output probabilities of the f network
        weights ([batch_size, num_rules]) - the w_network weights
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - constraint weight

    Output:
        [batch_size, num_classes] normalized q(y)
    '''
    psi = 1e-20  # guards against a zero normalizer
    # [batch_size, num_classes]
    product_term = pr_product_term(weights, rule_classes, num_classes, C)
    # unnormalized q(y) = p(y) * product over rules
    result = f_probs * product_term
    # normalize each row over classes
    normalizer = tf.reduce_sum(result, axis=-1, keepdims=True)
    result = result / (normalizer + psi)
    return result
def get_q_r_from_p(f_probs, weights, rule_classes, num_classes, C):
    '''
    Func Desc:
    Compute q(r_j = 1), the posterior probability that rule j correctly covers
    each instance, from the p (f_network) distribution.

    For each rule j:
        q(r_j = 1) proportional to w_j     * sum_y p(y) * exp_term[y, j] * prod_{j' != j}[...]
        q(r_j = 0) proportional to (1-w_j) * sum_y p(y) * 1              * prod_{j' != j}[...]
    where prod_{j' != j}[...] is the leave-one-out rule product.

    Input:
        f_probs ([batch_size, num_classes]) - output probabilities of the f network
        weights ([batch_size, num_rules]) - the w_network weights (P(r_j = 1))
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - constraint weight

    Output:
        [batch_size, num_rules] tensor of q(r_j = 1)
    '''
    psi = 1e-20  # guards against division by zero
    # [batch_size, num_classes]: full product over all rules
    pr_product_t = pr_product_term(weights, rule_classes, num_classes, C)
    # [batch_size, 1, num_classes]
    product_term = tf.expand_dims(pr_product_t, axis=1)
    # [num_rules, num_classes]
    class_rule_constraint = tf.transpose(
        exp_term_for_constraints(rule_classes, num_classes, C))
    # [1, num_rules, num_classes]
    class_rule_constraint = tf.expand_dims(class_rule_constraint, axis=0)
    # [batch_size, num_rules, 1]
    w = tf.expand_dims(weights, 2)
    # [batch_size, num_rules, num_classes]: rule j's own factor in the product
    divisior = w * class_rule_constraint + (1 - w)
    # leave-one-out product: divide rule j's factor back out -> prod_{j' != j}
    product_term = product_term / (divisior + psi)
    # [batch_size, 1, num_classes]
    f_probs = tf.expand_dims(f_probs, axis=1)
    # numerator of q(r_j=1): w_j * sum_y p(y) * exp_term[y,j] * prod_{j'!=j}[...]
    sum_over_y_r1 = tf.reduce_sum(
        product_term * f_probs * class_rule_constraint, axis=-1)
    prob_q_r_eq_1 = weights * sum_over_y_r1
    # numerator of q(r_j=0): (1 - w_j) * sum_y p(y) * prod_{j'!=j}[...]
    # BUG FIX: the previous version reduced over (f_probs * product_term) AFTER
    # product_term had already been multiplied by f_probs and the exp term,
    # double-counting p(y) and wrongly including exp_term in the r_j=0 branch.
    sum_over_y_r0 = tf.reduce_sum(product_term * f_probs, axis=-1)
    prob_q_r_eq_0 = (1 - weights) * sum_over_y_r0
    # normalize over r_j in {0, 1}; psi keeps the division finite
    prob_q_r_eq_1 = prob_q_r_eq_1 / (prob_q_r_eq_0 + prob_q_r_eq_1 + psi)
    return prob_q_r_eq_1
def theta_term_in_pr_loss(f_logits, f_probs, weights, rule_classes, num_classes, C, d):
    '''
    Func Desc:
    Compute the theta term in the pr loss (third term in equation 14) - used to
    supervise the f (classification) network from instances in U by driving the
    f distribution towards the constrained posterior q(y).

    Input:
        f_logits ([batch_size, num_classes]) - pre-softmax outputs of the f network
        f_probs ([batch_size, num_classes]) - softmax outputs of the f network
        weights ([batch_size, num_rules]) - the w_network weights
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - constraint weight
        d - labelled-set indicator; NOTE(review): docstrings in this file say
            [batch_size,1] but a later comment says [batch_size]; broadcasting
            against the [batch_size] cross-entropy is only correct for rank-1 d
            - confirm against the caller

    Output:
        scalar: mean cross-entropy between q(y) (as soft labels) and f_logits,
        masked to unlabelled instances
    '''
    # [batch_size, num_classes] constrained posterior over labels
    q_y = get_q_y_from_p(f_probs, weights, rule_classes, num_classes, C)
    # [batch_size]; v1 softmax_cross_entropy_with_logits does not backprop
    # into labels, so q_y acts as a fixed target here
    cross_ent_q_p = tf.nn.softmax_cross_entropy_with_logits(labels=q_y, logits=f_logits)
    # defined only for instances in U, so mask by (1-d)
    cross_ent_q_p = (1 - d) * cross_ent_q_p
    result = tf.reduce_mean(cross_ent_q_p)
    return result
def phi_term_in_pr_loss(m, w_logits, f_probs, weights, rule_classes, num_classes, C, d):
    '''
    Func Desc:
    Compute the phi term in the pr loss (fourth term in equation 14) - used to
    supervise the w (rule) network from instances in U by driving the w
    distribution towards the constrained posterior q(r_j = 1).

    Input:
        m ([batch_size, num_rules]) - mij = 1 if ith example is covered by jth rule
        w_logits ([batch_size, num_rules]) - pre-sigmoid outputs of the w network
        f_probs ([batch_size, num_classes]) - softmax outputs of the f network
        weights ([batch_size, num_rules]) - the w_network weights
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - constraint weight
        d - labelled-set indicator; NOTE(review): shape documented inconsistently
            in this file ([batch_size,1] vs [batch_size]) - confirm with caller

    Output:
        scalar: mean sigmoid cross-entropy between q(r_j=1) (as soft targets)
        and w_logits, restricted to fired rules and unlabelled instances
    '''
    # [batch_size, num_rules] posterior targets for each rule
    q_r_1 = get_q_r_from_p(f_probs, weights, rule_classes, num_classes, C)
    # [batch_size, num_rules] per-rule cross-entropy
    cross_ent_q_w = tf.nn.sigmoid_cross_entropy_with_logits(labels=q_r_1, logits=w_logits)
    # keep only rules that actually fired on each instance
    cross_ent_q_w = tf.reduce_sum(cross_ent_q_w * m, axis=-1)
    # defined only for instances in U, so mask by (1-d)
    cross_ent_q_w = cross_ent_q_w * (1 - d)
    cross_ent_q_w = tf.reduce_mean(cross_ent_q_w)
    return cross_ent_q_w
def pr_loss(m, f_logits, w_logits, f_probs, weights, rule_classes, num_classes, C, d):
    '''
    Func Desc:
    Compute the posterior-regularization loss: the sum of the theta term (third
    term in equation 14, supervising the f/classification network from U) and
    the phi term (fourth term in equation 14, supervising the w/rule network
    from U).

    Input:
        m ([batch_size, num_rules]) - rule firing matrix; mij = 1 if ith example
            is associated with jth rule
        f_logits ([batch_size, num_classes]) - pre-softmax outputs of the f network
        w_logits ([batch_size, num_rules]) - logit before sigmoid activation in w network
        f_probs ([batch_size, num_classes]) - output of f network
        weights ([batch_size, num_rules]) - the sigmoid output from w network
        rule_classes ([num_rules,1]) - a list of classes associated with the rules
        num_classes (int)
        C - lamda in equation 10 (hyperparameter)
        d - d[i] = 1 if ith instance is from the "d" set (labelled data),
            d[i] = 0 if from the "U" set

    Output:
        scalar: theta term + phi term
    '''
    theta_term = theta_term_in_pr_loss(
        f_logits, f_probs, weights, rule_classes, num_classes, C, d)
    phi_term = phi_term_in_pr_loss(
        m, w_logits, f_probs, weights, rule_classes, num_classes, C, d)
    return theta_term + phi_term