Implementation of the Xi correlation coefficient in pure PyTorch. After: S. Chatterjee, "A New Coefficient of Correlation", 2020.
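In the paper's notation, with the pairs (x_i, y_i) reordered so that x is increasing and r_i the rank of y_i, the statistic is xi_n = 1 - 3 * sum_i |r_(i+1) - r_i| / (n^2 - 1) when y has no ties, and xi_n = 1 - n * sum_i |r_(i+1) - r_i| / (2 * sum_i l_i * (n - l_i)) in general, where l_i is the number of indices j with y_j <= y_i.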
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ──────────────────────────────────────────────────────────────────────────────
from math import sqrt
from typing import List
from typing import Optional
from typing import Tuple

import torch
from safe_assert import safe_assert as sassert
from torch import Tensor
from torch.distributions import Normal

# ──────────────────────────────────────────────────────────────────────────────
__all__: List[str] = ["flat_rankdata", "xicor"]

# ──────────────────────────────────────────────────────────────────────────────
def _rankdata_extr(x: Tensor, extrmax: bool = False) -> Tensor:
    # Zero-based `min` (extrmax=False) or `max` (extrmax=True) ranks of a 1-D
    # tensor: every element of a tie group receives the smallest (resp. largest)
    # rank within the group.
    sorted_indices: Tensor = torch.argsort(x, stable=True)
    ranks: Tensor = torch.zeros_like(x, device=x.device)
    crank: int = len(ranks) if extrmax else 1
    # Sweep the sort order forward for `min` ranks, backward for `max` ranks;
    # the current rank is updated only when a new tie group begins.
    for i in range(len(x))[:: (1 - 2 * extrmax)]:
        idx: Tensor = sorted_indices[i]
        if ((i > 0 and not extrmax) or (i < len(ranks) - 1 and extrmax)) and x[
            idx
        ] != x[sorted_indices[i + 2 * extrmax - 1]]:
            crank: int = i + 1
        ranks[idx] = crank
    return ranks - 1

# ──────────────────────────────────────────────────────────────────────────────
def flat_rankdata(x: Tensor, method: str = "ordinal", offset: int = 0) -> Tensor:
    """
    Pure-PyTorch re-implementation of SciPy's stats.rankdata on flattened Tensors.

    Args:
        x (Tensor): Input tensor.
        method (str): Method to use for ranking. One of `ordinal`, `min`, or `max`.
        offset (int): Offset to add to the ranks (for compatibility with SciPy, use 1).

    Returns:
        Tensor: Ranks of the flattened input tensor, reshaped to `x.shape`.
    """
    sassert(
        method in ["ordinal", "min", "max"],
        "method must be either `ordinal`, `min`, or `max`",
    )
    xf: Tensor = x.view(-1)
    if method == "ordinal":
        # A double stable argsort yields zero-based ordinal ("first come,
        # first served") ranks.
        tbret: Tensor = torch.argsort(torch.argsort(xf, stable=True), stable=True)
    else:
        tbret: Tensor = _rankdata_extr(xf, method == "max")
    return tbret.view(x.shape) + offset
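
# Illustrative check (a sketch, not part of the original module): with input
# torch.tensor([2.0, 1.0, 2.0]) and offset=1, `flat_rankdata` agrees with
# scipy.stats.rankdata:
#   method="ordinal" -> [2, 1, 3]
#   method="min"     -> [2, 1, 2]
#   method="max"     -> [3, 1, 3]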
# ──────────────────────────────────────────────────────────────────────────────
def xicor(x: Tensor, y: Tensor, ties: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
    """
    Computes the xi correlation coefficient between two tensors, defined after
    S. Chatterjee, "A New Coefficient of Correlation", 2020.

    Args:
        x (Tensor): First input tensor.
        y (Tensor): Second input tensor.
        ties (Optional[bool]): Whether to account for ties in the ranking of `y`. If None, it is inferred automatically.

    Returns:
        Tuple[Tensor, Tensor]: The xi correlation coefficient and its asymptotic p-value under independence.
    """
    x: Tensor = x.view(-1)
    y: Tensor = y.view(-1)
    tlen: int = x.numel()
    sassert(
        tlen == (ytlen := y.numel()),
        f"Input tensors must have the same number of elements. Got {tlen} and {ytlen}.",
    )
    ties: bool = (torch.unique(y).numel() < tlen) if ties is None else ties
    y: Tensor = y[torch.argsort(x, stable=True)]  # reorder y by increasing x
    r: Tensor = flat_rankdata(y, method="ordinal", offset=1)
    num: Tensor = torch.sum(torch.abs(torch.diff(r)))  # sum of |r_(i+1) - r_i|
    if ties:
        # General case: xi = 1 - n * sum|dr| / (2 * sum l_i * (n - l_i)), with
        # l_i the number of indices j such that y_j <= y_i (the max-rank of y_i).
        ll: Tensor = flat_rankdata(y, method="max", offset=1)
        num: Tensor = num * tlen
        den: Tensor = 2 * torch.sum(ll * (tlen - ll))
    else:
        # Tie-free case: xi = 1 - 3 * sum|dr| / (n^2 - 1)
        num: Tensor = num * 3
        den: int = tlen**2 - 1
    sret: Tensor = 1 - num / den
    # Asymptotic one-sided p-value: under independence (and continuous `y`),
    # sqrt(n) * xi -> N(0, 2/5), so the null standard deviation of xi is
    # sqrt(2 / (5 * n)) (Chatterjee 2020, Theorem 2.1).
    pret = 1 - Normal(
        loc=0,
        scale=torch.tensor(
            sqrt(2 / (5 * tlen)), dtype=torch.float64, device=x.device
        ),
    ).cdf(sret)
    return sret, pret
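# ──────────────────────────────────────────────────────────────────────────────
# Minimal usage sketch (illustrative; not part of the original module). Exact
# values for the noisy case depend on the seed; for noiseless strictly monotone
# data the tie-free statistic equals exactly 1 - 3 / (n + 1).
if __name__ == "__main__":
    torch.manual_seed(0)
    n: int = 1000
    x: Tensor = torch.linspace(0.0, 1.0, n)
    xi_dep, p_dep = xicor(x, x**3)  # strictly monotone: xi close to 1
    xi_ind, p_ind = xicor(x, torch.randn(n))  # independent noise: xi close to 0
    print(f"monotone:    xi={xi_dep.item():.4f}  p={p_dep.item():.3e}")
    print(f"independent: xi={xi_ind.item():.4f}  p={p_ind.item():.3e}")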