Skip to content

Instantly share code, notes, and snippets.

@emaballarin
Last active August 9, 2024 02:42
Show Gist options
  • Save emaballarin/2e6a5087f84c2b69479a765937c668b9 to your computer and use it in GitHub Desktop.
Implementation of the Xi correlation coefficient in pure PyTorch. After: S. Chatterjee, "A New Coefficient of Correlation", 2020.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ──────────────────────────────────────────────────────────────────────────────
from math import sqrt
from typing import List
from typing import Optional
from typing import Tuple
import torch
from safe_assert import safe_assert as sassert
from torch import Tensor
from torch.distributions import Normal
# ──────────────────────────────────────────────────────────────────────────────
__all__: List[str] = ["flat_rankdata", "xicor"]
# ──────────────────────────────────────────────────────────────────────────────
def _rankdata_extr(x: Tensor, extrmax: bool = False) -> Tensor:
sorted_indices: Tensor = torch.argsort(x, stable=True)
ranks: Tensor = torch.zeros_like(x, device=x.device)
crank: int = len(ranks) if extrmax else 1
for i in range(len(x))[:: (1 - 2 * extrmax)]:
idx: Tensor = sorted_indices[i]
if ((i > 0 and not extrmax) or (i < len(ranks) - 1 and extrmax)) and x[
idx
] != x[sorted_indices[i + 2 * extrmax - 1]]:
crank: int = i + 1
ranks[idx] = crank
return ranks - 1
# ──────────────────────────────────────────────────────────────────────────────
def flat_rankdata(x: Tensor, method: str = "ordinal", offset: int = 0) -> Tensor:
    """
    Pure-PyTorch re-implementation of SciPy's stats.rankdata on flattened Tensors.

    Args:
        x (Tensor): Input tensor.
        method (str): Method to use for ranking. One of `ordinal`, `min`, or `max`.
        offset (int): Offset to add to the ranks (for compatibility with SciPy, use 1).

    Returns:
        Tensor: Ranks of the flattened input tensor.
    """
    sassert(
        method in ["ordinal", "min", "max"],
        "method must be either `ordinal`, `min`, or `max`",
    )
    flat: Tensor = x.view(-1)
    if method == "ordinal":
        # Double-argsort assigns distinct consecutive 0-based ranks, breaking
        # ties by position (stable sort).
        ranks: Tensor = torch.argsort(torch.argsort(flat, stable=True), stable=True)
    else:
        # Tie-aware 0-based ranking, resolved towards the smallest ("min") or
        # the largest ("max") rank within each group of equal values.
        ranks: Tensor = _rankdata_extr(flat, extrmax=(method == "max"))
    return ranks.view(x.shape) + offset
# ──────────────────────────────────────────────────────────────────────────────
def xicor(x: Tensor, y: Tensor, ties: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
    """
    Computes the xi correlation coefficient between two tensors, defined after
    S. Chatterjee, "A New Coefficient of Correlation", 2020.

    Args:
        x (Tensor): First input tensor.
        y (Tensor): Second input tensor.
        ties (Optional[bool]): Whether to consider ties in the ranking of `y`. If None, it is inferred automatically.

    Returns:
        Tuple[Tensor, Tensor]: The xi correlation coefficient and its p-value.

    Note:
        The p-value is based on the asymptotic result sqrt(n)*xi -> N(0, 2/5),
        which is derived for continuous (tie-free) `y`; with ties it is only an
        approximation. Requires at least two elements (den == 0 otherwise).
    """
    x = x.view(-1)
    y = y.view(-1)
    tlen: int = x.numel()
    sassert(
        tlen == (ytlen := y.numel()),
        f"Input tensors must have the same number of elements. Got {tlen} and {ytlen}.",
    )
    # Auto-detect ties in `y` by comparing unique count with element count.
    ties = (torch.unique(y).numel() < tlen) if ties is None else ties
    # Reorder `y` by increasing `x`. NOTE(review): ties in `x` are broken
    # deterministically (stable sort), not uniformly at random as in the
    # reference R implementation — confirm this is intended.
    y = y[torch.argsort(x, stable=True)]
    if ties:
        # Tie-aware estimator (Chatterjee 2020, Eq. for general y):
        #   xi = 1 - n * sum|r_{i+1} - r_i| / (2 * sum l_i * (n - l_i))
        # with r_i = #{j : y_j <= y_i} ("max" ranks) and l_i = #{j : y_j >= y_i}.
        # (Fixes: numerator previously used ordinal ranks; denominator
        # previously used r_i instead of l_i.)
        r: Tensor = flat_rankdata(y, method="max", offset=1)
        # l_i = n + 1 - (min rank of y_i) = #{j : y_j >= y_i}.
        lranks: Tensor = tlen + 1 - flat_rankdata(y, method="min", offset=1)
        num: Tensor = tlen * torch.sum(torch.abs(torch.diff(r)))
        den: Tensor = 2 * torch.sum(lranks * (tlen - lranks))
    else:
        # Tie-free estimator: xi = 1 - 3 * sum|r_{i+1} - r_i| / (n^2 - 1).
        r: Tensor = flat_rankdata(y, method="ordinal", offset=1)
        num: Tensor = 3 * torch.sum(torch.abs(torch.diff(r)))
        den: int = tlen**2 - 1
    sret: Tensor = 1 - num / den
    # Under independence sqrt(n) * xi ~ N(0, 2/5) asymptotically, so the
    # standard deviation of xi is sqrt(2 / (5 * n)).
    # (Fixes: scale was previously (2/5)/sqrt(n) instead of sqrt(2/(5n)).)
    pret: Tensor = 1 - Normal(
        loc=0,
        scale=torch.tensor(sqrt(2 / (5 * tlen)), dtype=torch.float64, device=x.device),
    ).cdf(sret)
    return sret, pret
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment