Created
November 3, 2021 14:43
-
-
Save vene/c62262904355b7d39356b95c829d4745 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Approximating the cross-entropy between two Power Sphericals. | |
Uses a second-order Taylor expansion to approximate E[log(1+z)]. | |
""" | |
# author: vlad n <[email protected]> | |
# license: mit | |
# documentation: https://hackmd.io/@vladn/SJ93wMevK | |
import numpy as np | |
import torch | |
from power_spherical import PowerSpherical, HypersphericalUniform | |
def ps_variance_dot(ps, y):
    """Compute the variance of dot(x, y) for x ~ ps.

    With ``alpha``/``beta`` the Beta parameters underlying the marginal of
    ``t`` and ``ratio = (alpha + beta) / (2 * beta)``, this returns

        V[x . y] = V[t] * ((1 - ratio) * (loc . y) ** 2 + ratio * (y . y)).

    The expression holds for any vector ``y``; for a unit vector the last
    factor ``y . y`` is 1. (Assumes the power_spherical parametrisation in
    which ``beta = (dim - 1) / 2`` and ``loc`` has unit norm -- TODO confirm
    if the package changes.)

    Args:
        ps: a PowerSpherical distribution (single, un-batched ``loc``
            presumably -- ``loc @ y`` must yield a scalar).
        y: vector with the same dimension as ``ps.loc``; it no longer has
            to be unit-norm.

    Returns:
        Tensor holding the variance of dot(x, y).
    """
    marginal_t = ps.base_dist.marginal_t
    alpha = marginal_t.base_dist.concentration1
    beta = marginal_t.base_dist.concentration0
    ratio = (alpha + beta) / (2 * beta)
    t_var = marginal_t.variance
    dp = ps.loc @ y  # projection of y onto the mean direction
    # Squared norm of y: it weights the variance from the component of x
    # orthogonal to loc. Computed rather than assumed to be 1, so non-unit
    # y also get the correct variance.
    yy = y @ y
    return t_var * ((1 - ratio) * dp ** 2 + ratio * yy)
def check_ps_variance_dot(
        dim=10,
        k=20,
        n_samples=1000):
    """Print analytic vs. Monte-Carlo mean and variance of z = dot(x, mu_q).

    Two directions are drawn uniformly on the sphere; ``n_samples`` points
    are drawn from PowerSpherical(mu_p, k) and projected onto mu_q.
    """
    sphere = HypersphericalUniform(dim=torch.tensor(dim))
    loc_p = sphere.rsample()
    loc_q = sphere.rsample()
    dist = PowerSpherical(loc=loc_p, scale=torch.tensor(k))
    proj = dist.rsample((n_samples,)) @ loc_q

    # Closed-form mean of the projection.
    analytic_mean = dist.mean @ loc_q
    print("mean z true:", analytic_mean.item())
    print("mean z num: ", proj.mean().item())
    print("V[z] tru: ", ps_variance_dot(dist, loc_q).item())
    print("V[z] num: ", proj.var().item())
def check_taylor(
        dim=10,
        k=3,
        n_samples=10000):
    """Print a Monte-Carlo estimate and a Taylor approximation of E[log(1 + z)].

    Here z = dot(x, mu_q) with x ~ PowerSpherical(mu_p, k). With m = E[z],
    the second-order approximation is

        E[log(1 + z)] ~= log(1 + m) - V[z] / (2 * (1 + m) ** 2).
    """
    sphere = HypersphericalUniform(dim=torch.tensor(dim))
    loc_p, loc_q = sphere.rsample(), sphere.rsample()
    dist = PowerSpherical(loc=loc_p, scale=torch.tensor(k))
    proj = dist.rsample((n_samples,)) @ loc_q

    m = dist.mean @ loc_q
    # log(1 + z) has second derivative -1 / (1 + z)^2, hence the correction.
    correction = ps_variance_dot(dist, loc_q) / (2 * (1 + m) ** 2)
    approx = torch.log1p(m) - correction

    print("MC: ", proj.log1p().mean().item())
    print("Tay: ", approx.item())
def main():
    """Run both numerical sanity checks with their default settings."""
    for check in (check_ps_variance_dot, check_taylor):
        check()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment