Created
September 9, 2023 19:20
-
-
Save norabelrose/302255212f943f6d475cd111893d3f76 to your computer and use it in GitHub Desktop.
Intersection of ranges
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from torch import Tensor | |
| import torch | |
def intersection_of_ranges(As: Tensor) -> Tensor:
    """Compute the orthogonal projector onto the intersection of the ranges
    (column spaces) of a collection of matrices.

    We use the formula from "Projectors on Intersections of Subspaces" by
    Ben-Israel (2015) <http://benisrael.net/ADI-BENISRAEL-AUG-29-13.pdf>:
    a vector lies in every range iff it is annihilated by every complementary
    projector ``I - P_i``, so the intersection is the null space of
    ``sum_i (I - P_i)``.

    Args:
        As: Tensor of shape ``(N, *batch, m, n)``. The ``N`` matrices along
            dim 0 are intersected; any further leading dims are treated as
            independent batch dims. Real or complex floating dtype.

    Returns:
        Tensor of shape ``(*batch, m, m)``: the orthogonal projector onto the
        intersection of the ranges. It is the zero matrix when the
        intersection is ``{0}``, and (numerically) the identity when every
        matrix has full row rank.
    """
    m, n = As.shape[-2:]
    # Full U so that tall matrices (m > n) also expose the left singular
    # vectors spanning the orthogonal complement of their range.
    u, s, _ = torch.linalg.svd(As, full_matrices=True)
    # Same default relative tolerance as torch.linalg.pinv: eps * max(m, n).
    # Singular values are always real, so take eps from their dtype; this
    # also works for complex inputs.
    rtol = max(m, n) * torch.finfo(s.dtype).eps

    # Pad the singular values with zeros up to length m: the trailing columns
    # of the full U correspond to zero singular values.
    s_full = torch.cat([s, s.new_zeros((*s.shape[:-1], m - s.shape[-1]))], dim=-1)
    # Directions at or below the pinv cutoff span the complement of the
    # range. `<=` is the exact negation of pinv's `s > cutoff`, so an
    # all-zero matrix correctly gets range {0}.
    null_mask = s_full <= s[..., :1] * rtol
    # Orthogonal projectors I - P_i onto the complement of each range
    # (mH rather than mT so this is Hermitian for complex inputs).
    Qs = u @ torch.diag_embed(null_mask.type_as(u)) @ u.mH

    # The intersection is the numerical null space of sum_i (I - P_i).
    L, U = torch.linalg.eigh(Qs.sum(0))
    # Eigenvalues lie in [0, N], and the largest is >= 1 whenever any
    # complement is nonzero. Clamping the scale at 1 handles the case where
    # every complement is zero (all ranges are the whole space): then every
    # eigenvalue is ~0 and we return the identity rather than zero.
    scale = L[..., -1:].clamp_min(1.0)
    in_intersection = L <= scale * rtol
    return U @ torch.diag_embed(in_intersection.type_as(U)) @ U.mH
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment