Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Last active October 19, 2023 03:56
Show Gist options
  • Save norabelrose/459ee8c2e6f3c57165aabdb1367a30ee to your computer and use it in GitHub Desktop.
Save norabelrose/459ee8c2e6f3c57165aabdb1367a30ee to your computer and use it in GitHub Desktop.
Ryan Greenblatt's cumulant estimation code
from typing import Optional
import torch
def get_all_the_cumulants(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    w: torch.Tensor,
    weights_in: Optional[torch.Tensor] = None,
):
    """Estimate the (optionally weighted) joint cumulants of x, y, z, w up to order four.

    Each input is a 2-D tensor whose first dimension indexes samples (the
    einsum subscripts below require shape ``(n, d_*)``); all four must share
    the same ``n``. Small-sample corrections of the k-statistic form are
    applied to the 2nd, 3rd and 4th order estimates.

    Args:
        x, y, z, w: Sample matrices of shapes ``(n, d_x)``, ``(n, d_y)``,
            ``(n, d_z)`` and ``(n, d_w)``.
        weights_in: Optional per-sample weights of shape ``(n,)``. They are
            cast to ``x``'s dtype/device and normalized to sum to 1. When
            omitted, every sample gets weight ``1 / n``.

    Returns:
        A tuple ``(x_mean, cov, third_cum, fourth_cum)`` with shapes
        ``(d_x,)``, ``(d_x, d_y)``, ``(d_x, d_y, d_z)`` and
        ``(d_x, d_y, d_z, d_w)``: the weighted mean of ``x``, the corrected
        x/y covariance, the corrected x/y/z third cumulant and the corrected
        x/y/z/w fourth cumulant.

    Note:
        The corrections divide by ``1 - a``, ``1 - 2a`` and ``1 - 3a`` where
        ``a`` is ``1/n`` (unweighted) or ``sum(w_i**2)`` (weighted). With
        uniform weights a finite covariance needs ``n >= 2``, a finite third
        cumulant ``n >= 3`` and a finite fourth cumulant ``n >= 4``; smaller
        samples yield inf/nan. For non-uniform weights the ``1/n -> sum(w_i**2)``
        substitution matches the standard reliability-weights correction for
        the covariance, but for orders three and four it is an unverified
        heuristic ("I don't remember why this is right, so use at your own
        risk" -- original author).
    """
    if weights_in is not None:
        # Match x's dtype/device up front so every einsum below sees a single
        # dtype; this also makes integer or boolean weights usable.
        weights = weights_in.to(dtype=x.dtype, device=x.device)
        weights = weights / weights.sum()
        # Inverse Kish effective sample size; equals 1/n for uniform weights.
        adj_val = (weights**2).sum()
    else:
        s = x.size(0)
        weights = torch.full((s,), 1 / s, dtype=x.dtype, device=x.device)
        adj_val = torch.tensor(1 / s, dtype=x.dtype, device=x.device)

    # Center every input by its own weighted mean.
    vals = [x, y, z, w]
    all_means = [torch.einsum("i x, i -> x", v, weights).unsqueeze(0) for v in vals]
    c_x, c_y, c_z, c_w = [v - m for v, m in zip(vals, all_means)]
    x_mean = all_means[0].squeeze(0)

    def pair_moment(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Weighted second moment sum_b w_b * a[b, i] * b[b, j] of centered inputs.
        return torch.einsum("b i, b j, b -> i j", a, b, weights)

    # Second order.
    cov_xy = pair_moment(c_x, c_y)
    corrected_cov = cov_xy / (1 - adj_val)

    # Third order: the third central moment is the (uncorrected) third cumulant.
    uncorrected_third_cum = torch.einsum(
        "b i, b j, b k, b -> i j k", c_x, c_y, c_z, weights
    )
    corrected_third_cum = uncorrected_third_cum / ((1 - adj_val) * (1 - 2 * adj_val))

    # Fourth order: the fourth central moment minus the sum over the three
    # ways of pairing (x, y, z, w) into two covariances.
    centered_4th_mom = torch.einsum(
        "b i, b j, b k, b l, b -> i j k l", c_x, c_y, c_z, c_w, weights
    )
    pair_sum = (
        torch.einsum("i j, k l -> i j k l", cov_xy, pair_moment(c_z, c_w))
        + torch.einsum("i k, j l -> i j k l", pair_moment(c_x, c_z), pair_moment(c_y, c_w))
        + torch.einsum("i l, j k -> i j k l", pair_moment(c_x, c_w), pair_moment(c_y, c_z))
    )
    # For a = 1/n this is the k-statistic n^2[(n+1)m4 - (n-1)pairs] / ((n-1)(n-2)(n-3)).
    corrected_4th_cum = ((1 + adj_val) * centered_4th_mom - (1 - adj_val) * pair_sum) / (
        (1 - adj_val) * (1 - 2 * adj_val) * (1 - 3 * adj_val)
    )

    return x_mean, corrected_cov, corrected_third_cum, corrected_4th_cum
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment