Gist crowsonkb/ecda42d6015622a6218d2fcda82adca2, created January 15, 2021 21:59
Computes the covariance matrix in PyTorch.
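In matrix form, using notation introduced here rather than taken from the gist, `cov_mean` treats the second-to-last dimension of `input` as V variables and the last dimension as N observations. For a single V × N matrix X it returns the mean vector μ and the V × V covariance matrix C, where b is 1 when `unbiased=True` (Bessel's correction) and 0 otherwise:

$$
\mu = \frac{1}{N} X \mathbf{1}_N, \qquad
C = \frac{1}{N - b}\,\bigl(X - \mu \mathbf{1}_N^\top\bigr)\bigl(X - \mu \mathbf{1}_N^\top\bigr)^\top, \qquad
b = \begin{cases} 1 & \text{unbiased} \\ 0 & \text{otherwise} \end{cases}
$$

Any leading dimensions act as batch dimensions, since `mean`, `transpose(-1, -2)`, and `@` all broadcast over them.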
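"""Computes the covariance matrix in PyTorch."""
import torch


def cov_mean(input, unbiased=True, keepdims=False):
    # Rows (dim -2) are variables, columns (dim -1) are observations.
    n = input.shape[-1] - int(unbiased)  # N - 1 if unbiased, else N
    mean = input.mean(dim=-1, keepdim=True)  # shape (..., V, 1)
    dev = input - mean
    mean = mean if keepdims else mean[..., 0]  # drop trailing dim unless kept
    return dev @ dev.transpose(-1, -2) / n, mean  # (..., V, V) covariance, mean


def cov(input, unbiased=True):
    """Returns only the covariance matrix; see cov_mean."""
    return cov_mean(input, unbiased=unbiased)[0]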
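A quick sanity check, assuming PyTorch 1.10 or later (the release that added `torch.cov`). The snippet copies `cov_mean` so it runs on its own, then compares it with `torch.cov`, which also treats rows as variables and applies Bessel's correction (`correction=1`) by default. The variable names and random test data are illustrative only.

```python
import torch


def cov_mean(input, unbiased=True, keepdims=False):
    # Copy of the gist's cov_mean so this snippet runs standalone.
    n = input.shape[-1] - int(unbiased)
    mean = input.mean(dim=-1, keepdim=True)
    dev = input - mean
    mean = mean if keepdims else mean[..., 0]
    return dev @ dev.transpose(-1, -2) / n, mean


torch.manual_seed(0)
x = torch.randn(3, 100, dtype=torch.float64)  # 3 variables, 100 observations

c, m = cov_mean(x)
ref = torch.cov(x)  # rows are variables, correction=1 by default
print(torch.allclose(c, ref))  # True
print(c.shape, m.shape)        # torch.Size([3, 3]) torch.Size([3])

# Leading dimensions are batch dimensions: 4 independent 3x3 covariances,
# here with the biased (divide-by-N) estimate and the mean's last dim kept.
xb = torch.randn(4, 3, 100, dtype=torch.float64)
cb, mb = cov_mean(xb, unbiased=False, keepdims=True)
print(cb.shape, mb.shape)      # torch.Size([4, 3, 3]) torch.Size([4, 3, 1])
```

`torch.cov` only accepts 1-D or 2-D input, so the batched case is where a hand-rolled version like this one remains useful.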