Last active
May 12, 2021 15:41
-
-
Save alexlimh/ec15f49c0ef46ac487ea4e89afb8b217 to your computer and use it in GitHub Desktop.
mutual information
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. Retrieve the top-k passages for each query.
# 2. Merge the top-k passages of all queries and take their union.
# 3. For each query, compute the dot-product scores against every passage in the union.
# 3.5. Save the dot-product scores to a .pkl or .npy file.
# 4. Pass the (queries x passages) score matrix to get_mi() to obtain the mutual information.
def get_mi(self, logits, tau=1.0):
    """Return the normalized mutual information between queries and passages.

    Each row of ``logits`` is turned into a temperature-scaled softmax
    distribution p(y|x) over the retrieved passages. The marginal p(y) is the
    plain average of those rows, i.e. queries are weighted uniformly. The
    mutual information is ``I(X; Y) = H(Y) - H(Y|X)``. Because X is uniform
    over Q queries, ``I(X; Y) <= H(X) = log Q``, so the result is divided by
    ``log Q`` to map it into [0, 1].

    Args:
        self: Unused by the computation; kept so the signature matches the
            original method-style definition.
        logits: Dot-product scores of shape (Q, P), where Q is the number of
            queries and P is the number of retrieved passages.
        tau: Softmax temperature controlling the sharpness of each
            distribution (smaller is sharper). Must be positive.

    Returns:
        The normalized mutual information as a float. With fewer than two
        queries the mutual information is zero and ``log Q`` is not a usable
        normalizer (the unguarded division gives NaN), so 0.0 is returned.

    Raises:
        ValueError: If ``logits`` is not 2-D or ``tau`` is not positive.
    """
    import numpy as np  # the original snippet never imports numpy at module level

    logits = np.asarray(logits, dtype=float)
    if logits.ndim != 2:
        raise ValueError(f"logits must be 2-D (Q, P), got shape {logits.shape}")
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")

    num_queries = logits.shape[0]
    if num_queries < 2:
        # A single query (or none) carries no information; avoid 0 / log(1).
        return 0.0

    eps = 1e-6  # keeps log() finite when a probability underflows to 0
    # Temperature-scaled softmax per row; subtracting the row max avoids
    # overflow in exp without changing the resulting distribution.
    scaled = (logits - logits.max(axis=-1, keepdims=True)) / tau
    exp_scaled = np.exp(scaled)
    probs = exp_scaled / exp_scaled.sum(axis=-1, keepdims=True)

    # p(y): average of the per-query distributions.
    marg_probs = probs.mean(axis=0)
    # H(Y): entropy of the marginal.
    marg_ent = -np.sum(marg_probs * np.log(marg_probs + eps))
    # H(Y|X): entropy of each query's distribution, averaged over queries below.
    cond_ent = -np.sum(probs * np.log(probs + eps), axis=-1)

    mi = marg_ent - cond_ent.mean()
    # Normalize by the upper bound log Q.
    return float(mi / np.log(num_queries))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment