Last active
March 13, 2025 14:40
-
-
Save norabelrose/3f7a553f4d69de3cf5bda93e2264a9c9 to your computer and use it in GitHub Desktop.
Fast, optimal Kronecker decomposition
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from einops import rearrange | |
from torch import Tensor | |
import torch | |
def kronecker_decompose(
    A: Tensor,
    m: int,
    n: int,
    *,
    k: int = 1,
    niter: int = 10,
    oversample: int = 5,
) -> tuple[Tensor, Tensor]:
    """Approximate `A` in Frobenius norm by a sum of `k` Kronecker products.

    Algorithm from Van Loan and Pitsianis (1993), "Approximation with Kronecker
    Products" <https://bit.ly/46hT5aY>: the blocks of `A` are rearranged so that
    every Kronecker product `kron(B, C)` becomes the rank-one matrix
    `vec(B) vec(C)^T`; a rank-`k` truncated SVD of the rearranged matrix then
    yields the factors. The SVD is computed with the randomized
    `torch.svd_lowrank`, so results depend on the torch RNG state.

    The returned factors satisfy
    `A ~= sum_i kron(left[..., i, :, :], right[..., i, :, :])`.

    Args:
        A: Matrix or batch of matrices to decompose, of shape
            `(..., m * m2, n * n2)`.
        m: Desired number of rows in the left Kronecker factor(s).
        n: Desired number of columns in the left Kronecker factor(s).
        k: Number of Kronecker products in the sum.
        niter: Number of subspace iterations for the low rank SVD algorithm.
        oversample: Extra sketch columns passed to `torch.svd_lowrank` beyond
            `k`; only the leading `k` singular triplets are kept. The sketch
            size is capped at `min(m * n, m2 * n2)`.

    Returns:
        Tuple of Kronecker factors (`left`, `right`) of shape `(..., k, m, n)`
        and `(..., k, A.shape[-2] // m, A.shape[-1] // n)` respectively.

    Raises:
        AssertionError: If the dimensions of `A` are not compatible with the
            desired number of rows and columns in the left Kronecker factor.
        ValueError: If `A` has fewer than two dimensions, if `m`, `n` or `k`
            is not positive, if `oversample` is negative, or if `k` exceeds
            `min(m * n, m2 * n2)`.
    """
    if A.ndim < 2:
        raise ValueError(f"A must have at least 2 dimensions, got {A.ndim}")
    if m <= 0 or n <= 0:
        raise ValueError(f"m and n must be positive, got m={m}, n={n}")
    if k <= 0:
        raise ValueError(f"k must be positive, got k={k}")
    if oversample < 0:
        raise ValueError(f"oversample must be non-negative, got {oversample}")

    *batch, rows, cols = A.shape
    m2, n2 = rows // m, cols // n
    # Raised explicitly (rather than via `assert`) so the check survives
    # `python -O`, while keeping the documented exception type.
    if (rows, cols) != (m * m2, n * n2):
        raise AssertionError("Dimensions do not match")

    max_rank = min(m * n, m2 * n2)
    if k > max_rank:
        raise ValueError(
            f"k={k} exceeds the maximum rank {max_rank} of the rearranged matrix"
        )

    # Van Loan-Pitsianis rearrangement: split rows into (m, m2) and columns
    # into (n, n2), then regroup as (m n) x (m2 n2).
    R = (
        A.reshape(*batch, m, m2, n, n2)
        .transpose(-3, -2)
        .reshape(*batch, m * n, m2 * n2)
    )

    # Sketch with a slightly overestimated rank and truncate afterwards; a
    # sketch of exactly `k` columns gives a noticeably worse rank-k subspace.
    q = min(k + oversample, max_rank)
    u, s, v = torch.svd_lowrank(R, q=q, niter=niter)
    u, s, v = u[..., :k], s[..., :k], v[..., :k]

    # Unflatten the factors: (..., m * n, k) -> (..., k, m, n), likewise for v.
    left = u.transpose(-2, -1).reshape(*batch, k, m, n)
    right = v.transpose(-2, -1).reshape(*batch, k, m2, n2)

    # Split each singular value evenly between the two factors.
    scale = s[..., None, None].sqrt()
    return left * scale, right * scale
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment