Last active May 12, 2021 15:41
Gist: alexlimh/ec15f49c0ef46ac487ea4e89afb8b217
mutual information
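import numpy as np

# Pipeline:
# 1. Retrieve the top-k passages for each query.
# 2. Merge the per-query top-k lists into their union (P unique passages).
# 3. Compute the query-passage dot products over the union, giving a Q x P score matrix.
# 3.5. (Optional) Cache the scores in a .pkl or .npy file.
# 4. Pass the score matrix to get_mi() to obtain the normalized mutual information.


def get_mi(logits, tau=1.0):
    """Normalized mutual information between queries and retrieved passages.

    logits: dot-product scores of shape (Q, P), where Q is the number of
        queries and P is the number of passages in the merged union.
    tau: softmax temperature (> 0); smaller values sharpen each distribution.
    """
    logits = np.asarray(logits, dtype=np.float64)
    num_queries = logits.shape[0]
    if num_queries < 2:
        # The result is normalized by log(Q), which is 0 for a single query.
        raise ValueError("get_mi needs at least two queries")
    # p(y|x): row-wise softmax; subtract the row max for numerical stability
    z = (logits - np.max(logits, axis=-1, keepdims=True)) / tau
    exp_z = np.exp(z)
    probs = exp_z / np.sum(exp_z, axis=-1, keepdims=True)
    # p(y): marginal over passages, assuming a uniform prior over queries
    marg_probs = np.mean(probs, axis=0, keepdims=True)
    # H(Y)
    marg_ent = -np.sum(marg_probs * np.log(marg_probs + 1e-6), axis=-1)
    # H(Y|X=x) for each query
    cond_ent = -np.sum(probs * np.log(probs + 1e-6), axis=-1)
    # MI = H(Y) - H(Y|X), with the conditional entropy averaged over queries
    mi = np.mean(marg_ent - cond_ent, axis=0)
    # Normalize to [0, 1]: I(X;Y) <= H(X) = log(Q) under the uniform query prior
    mi = mi / np.log(num_queries)
    return mi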
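For reference, the quantity that `get_mi` computes can be written out as below. Here s_ij denotes the score of query i against passage j in the Q x P matrix, and the uniform prior p(x_i) = 1/Q over queries is an assumption that the code makes implicitly through `np.mean`. The 1e-6 smoothing inside the code's logarithms is omitted.

```latex
\begin{aligned}
p(y_j \mid x_i) &= \frac{\exp(s_{ij}/\tau)}{\sum_{m=1}^{P}\exp(s_{im}/\tau)},
  \qquad p(y_j) = \frac{1}{Q}\sum_{i=1}^{Q} p(y_j \mid x_i) \\
H(Y) &= -\sum_{j=1}^{P} p(y_j)\log p(y_j) \\
H(Y \mid X) &= -\frac{1}{Q}\sum_{i=1}^{Q}\sum_{j=1}^{P} p(y_j \mid x_i)\log p(y_j \mid x_i) \\
\widehat{I} &= \frac{H(Y) - H(Y \mid X)}{\log Q} \in [0, 1],
  \qquad \text{since } 0 \le I(X;Y) \le H(X) = \log Q
\end{aligned}
```

The two extremes show why dividing by log Q gives a value in [0, 1]. If every query induces the same distribution over passages, then p(y | x_i) = p(y) for all i and the mutual information is 0. If P >= Q and each query puts all of its mass on a different passage, then H(Y | X) = 0 and H(Y) = log Q, so the normalized value is 1. Because of the finite temperature and the smoothing term, the code only approaches these values approximately.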
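Steps 1 to 3 can be sketched with brute-force inner-product search in NumPy. This is a minimal illustration under stated assumptions, not the author's retrieval code. It assumes the dense query and passage embeddings are already in memory as arrays, and `build_score_matrix`, `query_emb`, `passage_emb` and `k` are hypothetical names. A real setup would more likely run step 1 against an approximate nearest-neighbour index than score the whole corpus.

```python
import numpy as np


def build_score_matrix(query_emb, passage_emb, k):
    """Hypothetical sketch of steps 1-3: top-k retrieval, union, re-scoring.

    query_emb:   (Q, d) array of query embeddings (assumed input).
    passage_emb: (N, d) array of passage embeddings for the whole corpus.
    k:           number of passages retrieved per query (1 <= k <= N).
    Returns (scores, union_ids), where scores has shape (Q, P) and P is the
    number of unique passage ids across all top-k lists.
    """
    # Exact inner-product search over the full corpus: shape (Q, N)
    all_scores = query_emb @ passage_emb.T
    # Step 1: ids of the k highest-scoring passages per query (order not sorted)
    topk_ids = np.argpartition(-all_scores, k - 1, axis=1)[:, :k]
    # Step 2: union of all retrieved passage ids (np.unique returns them sorted)
    union_ids = np.unique(topk_ids)
    # Step 3: dot products of every query with every passage in the union
    scores = query_emb @ passage_emb[union_ids].T
    return scores, union_ids


# Usage on random embeddings (illustrative data only)
rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 16))
passages = rng.normal(size=(100, 16))
scores, union_ids = build_score_matrix(queries, passages, k=5)
# Step 3.5 (optional): np.save("scores.npy", scores)
print(scores.shape, union_ids.shape)
```

The resulting `scores` matrix has one row per query and one column per passage in the union, which is the Q x P input that step 4 passes to `get_mi`. With 4 queries and k = 5, P falls between 5 (all queries retrieve the same passages) and 20 (no overlap).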