@crowsonkb
Created January 15, 2021 21:59
Computes the covariance matrix in PyTorch.
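The code averages over the last dimension, so it treats rows as variables and columns as observations; any leading dimensions act as batch dimensions. Under that reading, for an input $X$ with $d$ variables and $N$ observations, the returned pair is $(C, \mu)$ with:

```latex
\mu = \frac{1}{N} X \mathbf{1}_N, \qquad
C = \frac{1}{N - b}\,\bigl(X - \mu \mathbf{1}_N^\top\bigr)\bigl(X - \mu \mathbf{1}_N^\top\bigr)^\top, \qquad
b = \begin{cases} 1 & \text{if unbiased} \\ 0 & \text{otherwise} \end{cases}
```

Here $\mathbf{1}_N$ is the all-ones vector of length $N$, so $C$ has shape $d \times d$.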
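```python
"""Computes the covariance matrix in PyTorch."""


def cov_mean(input, unbiased=True, keepdims=False):
    """Return (covariance, mean); rows are variables, columns are observations."""
    n = input.shape[-1] - unbiased  # divide by N - 1 if unbiased, else by N
    mean = input.mean(dim=-1, keepdim=True)  # shape (..., d, 1)
    dev = input - mean
    mean = mean if keepdims else mean[..., 0]  # drop the kept dim unless requested
    return dev @ dev.transpose(-1, -2) / n, mean  # covariance has shape (..., d, d)


def cov(input, unbiased=True):
    """Return only the covariance matrix from cov_mean."""
    return cov_mean(input, unbiased=unbiased)[0]
```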
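As a cross-check of the matrix-product formulation, the sketch below computes the same quantity the slow way: it accumulates the outer product of each observation's deviation from the mean, then divides by N - 1 (unbiased) or N. The helper name `cov_loop` is hypothetical. The comparison target is `torch.cov`, which exists only in PyTorch 1.10 and later; it uses the same variables-in-rows convention and takes a `correction` argument (1 for unbiased, 0 for biased).

```python
import torch


def cov_loop(x: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Hypothetical reference covariance for a 2-D tensor (variables, observations)."""
    d, n = x.shape
    mean = x.mean(dim=1)
    acc = torch.zeros(d, d, dtype=x.dtype)
    for j in range(n):
        dev = x[:, j] - mean          # deviation of observation j, shape (d,)
        acc += torch.outer(dev, dev)  # rank-1 update to the scatter matrix
    return acc / (n - int(unbiased))


torch.manual_seed(0)
x = torch.randn(3, 50, dtype=torch.float64)  # 3 variables, 50 observations
print(torch.allclose(cov_loop(x), torch.cov(x)))
print(torch.allclose(cov_loop(x, unbiased=False), torch.cov(x, correction=0)))
```

The loop costs one Python iteration per observation, which is why the gist's single `dev @ dev.transpose(-1, -2)` is preferable in practice: it performs the same sum of outer products as one batched matrix multiply.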