@norabelrose
Created September 9, 2023 19:20
Intersection of ranges
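```python
import torch
from torch import Tensor


def intersection_of_ranges(As: Tensor) -> Tensor:
    """Compute the intersection of the ranges of a batch of matrices.

    We use the formula from "Projectors on Intersections of Subspaces" by
    Ben-Israel (2015) <http://benisrael.net/ADI-BENISRAEL-AUG-29-13.pdf>.

    `As` has shape (k, m, n). The result is the (m, m) orthogonal projector
    onto range(A_1) ∩ ... ∩ range(A_k).
    """
    k, m, n = As.shape
    # Same default relative threshold as torch.linalg.pinv: max(m, n) * eps
    rtol = max(m, n) * torch.finfo(As.dtype).eps
    u, s, _ = torch.linalg.svd(As, full_matrices=False)
    # Singular values above the threshold span the numerical range
    mask = s > s[..., :1] * rtol
    # Construct orthogonal projectors onto the range of each matrix
    Ps = u @ torch.diag_embed(mask.type_as(u)) @ u.mT
    # I - P_i projects onto the orthogonal complement of range(A_i); building it
    # this way also covers tall matrices, where u has only n < m columns
    eye = torch.eye(m, dtype=As.dtype, device=As.device)
    # The intersection of the ranges is the null space of sum_i (I - P_i)
    L, U = torch.linalg.eigh((eye - Ps).sum(0))
    # Eigenvalues of a sum of k projectors lie in [0, k], so use an absolute
    # tolerance (a relative one fails when every range is the whole space)
    mask = L < k * rtol
    return U @ torch.diag_embed(mask.type_as(U)) @ U.mT
```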
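The function relies on a simple identity. If P_i is the orthogonal projector onto range(A_i), then a vector x lies in every range exactly when Σ‖(I − P_i)x‖² = xᵀ[Σ(I − P_i)]x = 0, so the intersection is the null space of the symmetric PSD matrix Σ(I − P_i). The NumPy sketch below is an illustrative re-implementation of that identity for checking results, not the gist's own code; the helper names `range_projector` and `intersection_projector` and the tolerances `rtol`/`atol` are hypothetical choices.

```python
from typing import Sequence

import numpy as np


def range_projector(A: np.ndarray, rtol: float = 1e-10) -> np.ndarray:
    """Orthogonal projector onto range(A), built from its left singular vectors."""
    u, s, _ = np.linalg.svd(A, full_matrices=False)
    # Keep singular directions above a relative threshold (the numerical rank)
    ur = u[:, s > rtol * s[0]]
    return ur @ ur.T


def intersection_projector(mats: Sequence[np.ndarray], atol: float = 1e-8) -> np.ndarray:
    """Projector onto the intersection of the ranges, as the null space of sum(I - P_i)."""
    m = mats[0].shape[0]
    # Each I - P_i projects onto the orthogonal complement of range(A_i)
    M = sum(np.eye(m) - range_projector(A) for A in mats)
    # M is symmetric PSD; its (numerically) zero-eigenvalue eigenvectors span the intersection
    w, V = np.linalg.eigh(M)
    null = V[:, w < atol]
    return null @ null.T


# Two planes in R^3 given as tall 3x2 matrices: the xy-plane, and the plane
# spanned by (1, 1, 0) and e3. They intersect in the line through (1, 1, 0).
A = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
B = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
P = intersection_projector([A, B])
print(np.round(P, 6))  # projector onto the line spanned by (1, 1, 0)
```

Because the complement projectors are formed as I − P_i from range projectors, the sketch handles tall inputs like the 3×2 matrices above; when the ranges only meet at the origin, no eigenvalue falls below `atol` and the result is the zero matrix.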