Created January 25, 2024 14:07
-
-
Save samarthbhargav/950e72a0c076d01ebc470266e8ef8ff1 to your computer and use it in GitHub Desktop.
Pairwise KL divergence between two sets of diagonal Gaussian distributions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def pairwise_kl_divergence(mean1, var1, mean2, var2, *, eps=1e-10):
    """Compute KL(P_i || Q_j) for every pair of diagonal Gaussians.

    P_i = N(mean1[i], diag(var1[i])) and Q_j = N(mean2[j], diag(var2[j])).
    With diagonal covariances the closed form reduces to sums over D:

        KL(P || Q) = 0.5 * ( sum_d log(var_q[d] / var_p[d]) - k
                             + sum_d var_p[d] / var_q[d]
                             + sum_d (mu_p[d] - mu_q[d]) ** 2 / var_q[d] )

    Args:
        mean1: Means of the first set, shape (B1, D).
        var1: Diagonal variances of the first set, shape (B1, D).
        mean2: Means of the second set, shape (B2, D).
        var2: Diagonal variances of the second set, shape (B2, D).
        eps: Variances are clamped to at least this value so the log and
            reciprocal below stay finite. Defaults to 1e-10.

    Returns:
        Tensor of shape (B1, B2) whose (i, j) entry is KL(P_i || Q_j).
        Entries are non-negative (up to floating-point error) and zero
        when P_i and Q_j are identical. Callers that want a similarity
        score (larger = closer) should negate the result.

    Note:
        The mean-difference term materializes a (B1, B2, D) intermediate,
        so peak memory grows with B1 * B2 * D.
    """
    # Clamp to a strictly positive floor: log(0) and 1/0 would give inf/nan.
    var1 = torch.clamp(var1, min=eps)
    var2 = torch.clamp(var2, min=eps)
    k = mean1.size(1)

    # The log-determinant of a diagonal covariance is the sum of the log
    # variances. Shapes: (B1,) and (B2,).
    logdet1 = torch.log(var1).sum(1)
    logdet2 = torch.log(var2).sum(1)

    # Entry (i, j) is log det(var2[j]) - log det(var1[i]) - k; shape (B1, B2).
    log_det_term = logdet2.unsqueeze(0) - logdet1.unsqueeze(1) - k

    var2_inv = 1 / var2
    # For diagonal matrices, tr(Sigma2^-1 Sigma1) = sum_d var1[i, d] / var2[j, d],
    # which for all pairs at once is a single matrix product; shape (B1, B2).
    trace_term = var1.matmul(var2_inv.T)

    # Squared mean differences for every pair, shape (B1, B2, D), weighted
    # by the inverse variances of the second distribution and summed over D.
    mean_diff_sq = (mean1.unsqueeze(1) - mean2.unsqueeze(0)) ** 2
    mahalanobis_term = (mean_diff_sq * var2_inv.unsqueeze(0)).sum(dim=-1)

    # KL divergence is +0.5 times the bracketed sum (hence >= 0).
    return 0.5 * (log_det_term + trace_term + mahalanobis_term)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment